SameDiff execution, TF and memory management overhaul (#10)
* SameDiff execution memory management improvements, round 1
* Round 2
* Round 3
* Clear closed array references from node outputs; slight change to OpValidation internals to not rely on cached op outputs
* Next steps
* Next step
* Cleanup
* More polish
* Fix
* Add WeakIdentityHashmap
* Session fixes for control ops and next steps
* More fixes
* Fixes
* First steps for training session + in-line updating
* Next steps
* Fix losses and history during training
* More fixes
* BiasAdd and other fixes
* Don't use SDVariable.getArr() in TFGraphTestAllHelper (import tests)
* Small fix
* Cleanup
* First steps for new dependency tracking approach
* Start integrating dependency tracking for memory management
* Non-control op dependency tracking works/passes
* Switch/merge
* Next steps
* Cleanup and next steps
* Small fixes
* Fix dependency tracking issue for initial variables/constants
* Add check for aliases when determining if safe to close array
* First pass on new TF graph import class
* Import fixes, op fixes
* Cleanup and fixes for new TF import mapper
* More cleanup
* Partial implementation of new dependency tracker
* Next steps
* AbstractDependencyTracker for shared code
* Overhaul SameDiff graph execution (dependency tracking)
* More fixes, cleanup, next steps
* Add no-op memory manager, cleanup, fixes
* Fix switch dependency tracking
* INDArray.toString: no exception on closed arrays, just note closed
* Fix enter and exit dependency tracking
* TensorArray memory management fixes
* Add unique ID for INDArray instances
* Fix memory management for NextIteration outputs in multi-iteration loops
* Remove (now unnecessary) special case handling for nested enters
* Cleanup
* Handle control dependencies during execution; javadoc for memory managers
* Cleanup, polish, code comments, javadoc
* Cleanup and more javadoc
* Add memory validation for all TF import tests - ensure all arrays (except outputs) are released
* Clean up arrays waiting on unexecuted ops at the end of execution
* Fixes for enter op memory management in the context of multiple non-nested loops/frames
* Fix order of operation issues for dependency tracker
* Always clear op fields after execution to avoid leaks or unintended array reuse
* Small fixes
* Re-implement dtype conversion
* Fix for control dependencies execution (dependency tracking)
* Fix TF import overrides and filtering
* Fix for constant enter array dependency tracking
* DL4J fixes
* More DL4J fixes
* Cleanup and polish
* More polish and javadoc
* More logging level tweaks, small DL4J fix
* Small fix to DL4J SameDiffLayer
* Fix empty array deserialization, add extra deserialization checks
* FlatBuffers control dep serialization fixes; test serialization as part of all TF import tests
* Variable control dependencies serialization fix
* Fix issue with removing inputs for ops
* FlatBuffers NDArray deserialization fix
* FlatBuffers NDArray deserialization fix
* Small fix
* Small fix
* Final cleanup/polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
branch master
parent f31661e13b
commit 3f0b4a2d4c
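Several of the commits above (alias checks before closing, "always clear op fields", the TF-import memory validation) reduce to one idiom that recurs throughout the DL4J changes below: duplicate what must outlive the call, then close the source array if it is safe. A minimal sketch of that idiom, using only the `INDArray` methods the diff itself relies on (`isAttached()`, `closeable()`, `close()`); the helper class is illustrative, not part of the PR:

```java
import org.nd4j.linalg.api.ndarray.INDArray;

public class CloseIfSafe {
    /**
     * Release an array's buffer eagerly when it is safe to do so:
     * workspace-attached arrays are owned by their workspace, and
     * views/aliases report closeable() == false.
     */
    public static void closeIfSafe(INDArray arr) {
        if (arr != null && !arr.isAttached() && arr.closeable()) {
            arr.close();    //Buffer can be deallocated or reused immediately
        }
    }
}
```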
@@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
     @Test
     public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
 
         MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                 .seed(12345)
@@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
     @Test
     public void testComputationGraphFrozenLayerParamsAfterBackprop() {
+        Nd4j.getRandom().setSeed(12345);
 
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
         String frozenBranchName = "B1-";
         String unfrozenBranchName = "B2-";
 
@@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
      */
    @Test
    public void testFrozenLayerVsSgd() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
 
        MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
-                .layer(0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build()
-                )
-                .layer(1,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(3)
-                                .nOut(4)
-                                .build()
-                ).layer(2,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(4)
-                                .nOut(2)
-                                .build()
-
-                ).layer(3,
-                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .activation(Activation.TANH)
-                                .nIn(2)
-                                .nOut(1)
-                                .build()
-                )
+                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
+                .layer(1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
+                .layer(2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
+                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
                .build();
 
        MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
@@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
-                .layer(0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build()
-                )
-                .layer(1,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(3)
-                                        .nOut(4)
-                                        .build()
-                        )
-                )
-                .layer(2,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(4)
-                                        .nOut(2)
-                                        .build()
-                        )
-                ).layer(3,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                        .activation(Activation.TANH)
-                                        .nIn(2)
-                                        .nOut(1)
-                                        .build()
-                        )
-                )
+                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
+                .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
+                .layer(2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
+                .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
                .build();
 
        MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
        frozenNetwork.init();
@@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
 
    @Test
    public void testComputationGraphVsSgd() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
        String frozenBranchName = "B1-";
        String unfrozenBranchName = "B2-";
 
@@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
-                .addLayer(initialLayer,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        "input"
-                )
-                .addLayer(frozenBranchUnfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(frozenBranchFrozenLayer1,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(3)
-                                        .nOut(4)
-                                        .build()
-                        ),
-                        frozenBranchUnfrozenLayer0
-                )
+                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input")
+                .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
+                .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                        new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0)
                .addLayer(frozenBranchFrozenLayer2,
                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(4)
-                                        .nOut(2)
-                                        .build()
-                        ),
-                        frozenBranchFrozenLayer1
-                )
-                .addLayer(unfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(unfrozenLayer1,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(2)
-                                .build(),
-                        unfrozenLayer0
-                )
-                .addLayer(unfrozenBranch2,
-                        new DenseLayer.Builder()
-                                .nIn(2)
-                                .nOut(1)
-                                .build(),
-                        unfrozenLayer1
-                )
-                .addVertex("merge",
-                        new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                        .activation(Activation.TANH)
-                                        .nIn(3)
-                                        .nOut(1)
-                                        .build()
-                        ),
-                        "merge"
-                )
+                        new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1)
+                .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer)
+                .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0)
+                .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1)
+                .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+                .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge")
                .setOutputs(frozenBranchOutput)
                .build();
 
@@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
-                .addLayer(initialLayer,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        "input"
-                )
-                .addLayer(frozenBranchUnfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(frozenBranchFrozenLayer1,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(3)
-                                .nOut(4)
-                                .build(),
-                        frozenBranchUnfrozenLayer0
-                )
-                .addLayer(frozenBranchFrozenLayer2,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(4)
-                                .nOut(2)
-                                .build()
-                        ,
-                        frozenBranchFrozenLayer1
-                )
-                .addLayer(unfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(unfrozenLayer1,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(2)
-                                .build(),
-                        unfrozenLayer0
-                )
-                .addLayer(unfrozenBranch2,
-                        new DenseLayer.Builder()
-                                .nIn(2)
-                                .nOut(1)
-                                .build(),
-                        unfrozenLayer1
-                )
-                .addVertex("merge",
-                        new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput,
-                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .activation(Activation.TANH)
-                                .nIn(3)
-                                .nOut(1)
-                                .build()
-                        ,
-                        "merge"
-                )
+                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input")
+                .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
+                .addLayer(frozenBranchFrozenLayer1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(), frozenBranchUnfrozenLayer0)
+                .addLayer(frozenBranchFrozenLayer2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(), frozenBranchFrozenLayer1)
+                .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer)
+                .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0)
+                .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1)
+                .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+                .addLayer(frozenBranchOutput, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(), "merge")
                .setOutputs(frozenBranchOutput)
                .build();
 
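The test hunks above all make the same substitution: instead of the per-call seeded `Nd4j.rand(rows, cols, seed)` overload, the default RNG is seeded once via `Nd4j.getRandom().setSeed(12345)` and the unseeded `Nd4j.rand(rows, cols)` overload is used. A minimal sketch of the new pattern; the class wrapper is illustrative:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class SeedingSketch {
    public static void main(String[] args) {
        // Seed the default RNG once, then draw without per-call seeds;
        // repeated runs with the same seed produce the same test data.
        Nd4j.getRandom().setSeed(12345);
        INDArray features = Nd4j.rand(100, 4);
        INDArray labels = Nd4j.rand(100, 1);
        DataSet randomData = new DataSet(features, labels);
        System.out.println(randomData.numExamples());   //100
    }
}
```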
@@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                Map<String,INDArray> placeholders = new HashMap<>();
                placeholders.put("input", f);
                placeholders.put("label", l);
-                sd.exec(placeholders, lossMse.getVarName());
-                INDArray outSd = a1.getArr();
+                Map<String,INDArray> map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName());
+                INDArray outSd = map.get(a1.getVarName());
                INDArray outDl4j = net.output(f);
 
                assertEquals(testName, outDl4j, outSd);
@@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
 
                //Check score
                double scoreDl4j = net.score();
-                double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
+                double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore();
                assertEquals(testName, scoreDl4j, scoreSd, 1e-6);
 
                double lossRegScoreSD = sd.calcRegularizationScore();
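The change above replaces `sd.exec(...)` followed by cached `SDVariable.getArr()` reads with `sd.output(...)`, which returns exactly the requested arrays in a map. Under the new memory manager, intermediate arrays may be released after execution, so reading cached arrays is no longer safe. A self-contained sketch of the pattern; the graph and variable names are illustrative:

```java
import java.util.HashMap;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OutputApiSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("input", Nd4j.dataType(), -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(4, 1));
        SDVariable out = in.mmul("out", w);

        Map<String, INDArray> placeholders = new HashMap<>();
        placeholders.put("input", Nd4j.rand(3, 4));

        // Request the variables you need by name; anything not requested
        // may have its array released once it is no longer required.
        Map<String, INDArray> results = sd.output(placeholders, "out");
        INDArray outArr = results.get("out");
        System.out.println(outArr);
    }
}
```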
@@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer {
        val weightsShape = new long[] {outputSize, featureDim, nOut};
        params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
        if (hasBias) {
-            val biasShape = new long[] {1, nOut};
+            val biasShape = new long[] {nOut};
            params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
        }
    }
@@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer {
 
        if (hasBias) {
            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b);
+            SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
            return activation.asSameDiff("out", sameDiff, biasAddedResult);
        } else {
            return activation.asSameDiff("out", sameDiff, result);
@@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer {
        val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut};
        params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
        if (hasBias) {
-            val biasShape = new long[] {1, nOut};
+            val biasShape = new long[] {nOut};
            params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
        }
    }
@@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer {
 
        if (hasBias) {
            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b);
+            SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true);
            return activation.asSameDiff("out", sameDiff, biasAddedResult);
        } else {
            return activation.asSameDiff("out", sameDiff, permutedResult);
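The "BiasAdd and other fixes" commit switches biases to rank-1 shape `[nOut]` and gives `biasAdd` an explicit third boolean argument selecting the layout convention (assumed here to be an nchw-style flag; only the three-argument call itself is confirmed by the hunks above). A small hedged sketch:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class BiasAddSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable input = sd.var("in", Nd4j.rand(2, 3));
        // Rank-1 bias of shape [nOut] rather than the old [1, nOut] row vector
        SDVariable bias = sd.var("b", Nd4j.rand(new int[]{3}));
        // Third argument selects the layout convention (assumed nchw flag)
        SDVariable out = sd.nn().biasAdd(input, bias, true);
        System.out.println(out.eval());
    }
}
```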
@@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex {
        }
 
        try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
-            return Nd4j.hstack(in);
+            return Nd4j.concat(1, in);
        }
    }
 
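`Nd4j.hstack` and `Nd4j.concat(1, ...)` agree for 2D inputs, but concatenating explicitly along dimension 1 also behaves predictably for higher-rank activations. A quick illustration:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ConcatSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(2, 3);
        INDArray b = Nd4j.rand(2, 5);
        // Concatenate along dimension 1 (columns): result shape [2, 8]
        INDArray merged = Nd4j.concat(1, a, b);
        System.out.println(java.util.Arrays.toString(merged.shape()));
    }
}
```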
@@ -134,6 +134,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
        Gradient g = new DefaultGradient();
 
        INDArray[] dLdIns;
+        boolean[] noClose = new boolean[getNumInputArrays()];
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
            if(sameDiff == null){
                doInit();
@@ -167,20 +168,21 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 
            //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
            //TODO Find a more efficient solution for this
+            List<String> required = new ArrayList<>(inputNames.size());    //Ensure that the input placeholder gradients are calculated
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                INDArray arr = e.getValue();
                sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
            }
 
-            List<String> required = new ArrayList<>(inputNames.size());    //Ensure that the input placeholder gradients are calculated
-            for(String s : inputNames){
-                required.add(sameDiff.getVariable(s).gradient().getVarName());
-            }
-            sameDiff.execBackwards(phMap, required);
+            required.addAll(paramTable.keySet());
+            required.addAll(inputNames);
+
+            Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
            for(String s : paramTable.keySet() ){
-                INDArray sdGrad = sameDiff.grad(s).getArr();
+                INDArray sdGrad = gradsMap.get(s);
                INDArray dl4jGrad = gradTable.get(s);
                dl4jGrad.assign(sdGrad);    //TODO OPTIMIZE THIS
+                sdGrad.close();             //TODO optimize this
                g.gradientForVariable().put(s, dl4jGrad);
            }
 
@@ -195,13 +197,18 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
                    //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
                    // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
                    dLdIns[j] = epsilon;
+                    noClose[j] = true;
                }
            }
        }
 
        //TODO optimize
        for( int i=0; i<dLdIns.length; i++ ){
+            INDArray before = dLdIns[i];
            dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
+            if(!noClose[i]){
+                before.close();
+            }
        }
 
        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@@ -110,7 +110,13 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
            sameDiff.clearPlaceholders(true);
            sameDiff.clearOpInputs();
 
-            return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+            INDArray ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+            if(!result.isAttached() && result.closeable()) {
+                //May be attached in rare edge case - for identity, or if gradients are passed through from output to input
+                // unchanged, as in identity, add scalar, etc
+                result.close();
+            }
+            return ret;
        }
    }
 
@@ -122,6 +128,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
        Gradient g = new DefaultGradient();
 
        INDArray dLdIn;
+        boolean noCloseEps = false;
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
            if(sameDiff == null){
                doInit();
@@ -151,26 +158,25 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
            }
 
            List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
-            requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
-            for(String s : paramTable.keySet()){
-                requiredGrads.add(sameDiff.grad(s).getVarName());
-            }
+            requiredGrads.add(INPUT_KEY);
+            requiredGrads.addAll(paramTable.keySet());
 
-            sameDiff.execBackwards(phMap, requiredGrads);
+            Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
            for(String s : paramTable.keySet() ){
-                INDArray sdGrad = sameDiff.grad(s).getArr();
+                INDArray sdGrad = m.get(s);
                INDArray dl4jGrad = gradTable.get(s);
                dl4jGrad.assign(sdGrad);    //TODO OPTIMIZE THIS
                g.gradientForVariable().put(s, dl4jGrad);
+                sdGrad.close();
            }
 
-            SDVariable v = sameDiff.grad(INPUT_KEY);
-            dLdIn = v.getArr();
+            dLdIn = m.get(INPUT_KEY);
 
-            if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
+            if(dLdIn == null && fn.getGradPlaceholderName().equals(INPUT_KEY)){
                //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
                // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
                dLdIn = epsilon;
+                noCloseEps = true;
            }
        }
 
@@ -178,7 +184,12 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();
 
-        return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));    //TODO OPTIMIZE THIS
+        Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));    //TODO OPTIMIZE THIS
+        if(!noCloseEps && !dLdIn.isAttached() && dLdIn.closeable()) {
+            //Edge case: identity etc - might just pass gradient array through unchanged
+            dLdIn.close();
+        }
+        return ret;
    }
 
    /**Returns the parameters of the neural network as a flattened row vector
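Both DL4J wrappers above now follow the same pattern: request gradients from `calculateGradients` by variable name (not by the internal grad-variable name), copy them out, and close the returned arrays so the new memory manager can reuse the buffers. A standalone sketch of that pattern; the graph and names are illustrative:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class GradientApiSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", Nd4j.dataType(), -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(4, 1));
        SDVariable loss = sd.math().square(in.mmul(w)).sum("loss");
        sd.setLossVariables("loss");

        Map<String, INDArray> phMap = new HashMap<>();
        phMap.put("in", Nd4j.rand(3, 4));

        // Gradients are requested by the *variable* name, not the grad-variable name
        Map<String, INDArray> grads = sd.calculateGradients(phMap, Arrays.asList("w", "in"));
        INDArray dLdW = grads.get("w").dup();   //Copy out what must survive...
        grads.get("w").close();                 //...then release the source buffer
        System.out.println(dLdW.shapeInfoToString());
    }
}
```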
@@ -106,6 +106,12 @@ public struct FlatNode : IFlatbufferObject
#endif
  public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
  public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
+  public string ControlDeps(int j) { int o = __p.__offset(42); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int ControlDepsLength { get { int o = __p.__offset(42); return o != 0 ? __p.__vector_len(o) : 0; } }
+  public string VarControlDeps(int j) { int o = __p.__offset(44); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
+  public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
 
  public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
      int id = 0,
@@ -126,9 +132,15 @@ public struct FlatNode : IFlatbufferObject
      VectorOffset outputNamesOffset = default(VectorOffset),
      StringOffset opNameOffset = default(StringOffset),
      VectorOffset outputTypesOffset = default(VectorOffset),
-      Offset<FlatArray> scalarOffset = default(Offset<FlatArray>)) {
-    builder.StartObject(19);
+      Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
+      VectorOffset controlDepsOffset = default(VectorOffset),
+      VectorOffset varControlDepsOffset = default(VectorOffset),
+      VectorOffset controlDepForOffset = default(VectorOffset)) {
+    builder.StartObject(22);
    FlatNode.AddOpNum(builder, opNum);
+    FlatNode.AddControlDepFor(builder, controlDepForOffset);
+    FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
+    FlatNode.AddControlDeps(builder, controlDepsOffset);
    FlatNode.AddScalar(builder, scalarOffset);
    FlatNode.AddOutputTypes(builder, outputTypesOffset);
    FlatNode.AddOpName(builder, opNameOffset);
@@ -150,7 +162,7 @@ public struct FlatNode : IFlatbufferObject
    return FlatNode.EndFlatNode(builder);
  }
 
-  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(19); }
+  public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
  public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
  public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
  public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
@@ -200,6 +212,18 @@ public struct FlatNode : IFlatbufferObject
  public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
  public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
  public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
+  public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(19, controlDepsOffset.Value, 0); }
+  public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddVarControlDeps(FlatBufferBuilder builder, VectorOffset varControlDepsOffset) { builder.AddOffset(20, varControlDepsOffset.Value, 0); }
+  public static VectorOffset CreateVarControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateVarControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddControlDepFor(FlatBufferBuilder builder, VectorOffset controlDepForOffset) { builder.AddOffset(21, controlDepForOffset.Value, 0); }
+  public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
  public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
    int o = builder.EndObject();
    return new Offset<FlatNode>(o);
@@ -66,6 +66,12 @@ public final class FlatNode extends Table {
  public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
  public FlatArray scalar() { return scalar(new FlatArray()); }
  public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
+  public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
+  public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
 
  public static int createFlatNode(FlatBufferBuilder builder,
      int id,
@@ -86,9 +92,15 @@ public final class FlatNode extends Table {
      int outputNamesOffset,
      int opNameOffset,
      int outputTypesOffset,
-      int scalarOffset) {
-    builder.startObject(19);
+      int scalarOffset,
+      int controlDepsOffset,
+      int varControlDepsOffset,
+      int controlDepForOffset) {
+    builder.startObject(22);
    FlatNode.addOpNum(builder, opNum);
+    FlatNode.addControlDepFor(builder, controlDepForOffset);
+    FlatNode.addVarControlDeps(builder, varControlDepsOffset);
+    FlatNode.addControlDeps(builder, controlDepsOffset);
    FlatNode.addScalar(builder, scalarOffset);
    FlatNode.addOutputTypes(builder, outputTypesOffset);
    FlatNode.addOpName(builder, opNameOffset);
@@ -110,7 +122,7 @@ public final class FlatNode extends Table {
    return FlatNode.endFlatNode(builder);
  }
 
-  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); }
+  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
  public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
  public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
  public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@@ -150,6 +162,15 @@ public final class FlatNode extends Table {
  public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
  public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
  public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
+  public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
+  public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
  public static int endFlatNode(FlatBufferBuilder builder) {
    int o = builder.endObject();
    return o;
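The three new vtable slots (19-21) carry a node's op control dependencies, variable control dependencies, and reverse "control-dep-for" edges as vectors of strings. A hedged sketch of writing and reading one of them with the generated Java API above; the field values are illustrative, and `getRootAsFlatNode` is assumed to be the standard FlatBuffers-generated root accessor for this table:

```java
import com.google.flatbuffers.FlatBufferBuilder;

public class ControlDepsSketch {
    public static void main(String[] args) {
        FlatBufferBuilder b = new FlatBufferBuilder(64);
        // Control deps are serialized as a vector of string offsets (op names)
        int dep0 = b.createString("someOp");
        int controlDeps = FlatNode.createControlDepsVector(b, new int[]{dep0});

        FlatNode.startFlatNode(b);
        FlatNode.addId(b, 1);
        FlatNode.addControlDeps(b, controlDeps);
        int node = FlatNode.endFlatNode(b);
        b.finish(node);

        // Read the vector back using the new accessors
        FlatNode fn = FlatNode.getRootAsFlatNode(b.dataBuffer());
        for (int j = 0; j < fn.controlDepsLength(); j++) {
            System.out.println(fn.controlDeps(j));   //Prints "someOp"
        }
    }
}
```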
@@ -294,7 +294,52 @@ class FlatNode(object):
            return obj
        return None
 
-def FlatNodeStart(builder): builder.StartObject(19)
+    # FlatNode
+    def ControlDeps(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatNode
+    def ControlDepsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # FlatNode
+    def VarControlDeps(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatNode
+    def VarControlDepsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # FlatNode
+    def ControlDepFor(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatNode
+    def ControlDepForLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def FlatNodeStart(builder): builder.StartObject(22)
 def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
 def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
 def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
@@ -324,4 +369,10 @@ def FlatNodeAddOpName(builder, opName): builder.PrependUOffsetTRelativeSlot(16,
 def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0)
 def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
 def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0)
+def FlatNodeAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(19, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
+def FlatNodeStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTRelativeSlot(20, flatbuffers.number_types.UOffsetTFlags.py_type(varControlDeps), 0)
+def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
+def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
 def FlatNodeEnd(builder): return builder.EndObject()
@@ -37,6 +37,12 @@ public struct FlatVariable : IFlatbufferObject
  public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
  public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } }
  public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
+  public string ControlDeps(int j) { int o = __p.__offset(18); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int ControlDepsLength { get { int o = __p.__offset(18); return o != 0 ? __p.__vector_len(o) : 0; } }
+  public string ControlDepForOp(int j) { int o = __p.__offset(20); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int ControlDepForOpLength { get { int o = __p.__offset(20); return o != 0 ? __p.__vector_len(o) : 0; } }
+  public string ControlDepsForVar(int j) { int o = __p.__offset(22); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
+  public int ControlDepsForVarLength { get { int o = __p.__offset(22); return o != 0 ? __p.__vector_len(o) : 0; } }
 
  public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
      Offset<IntPair> idOffset = default(Offset<IntPair>),
@@ -45,8 +51,14 @@ public struct FlatVariable : IFlatbufferObject
      VectorOffset shapeOffset = default(VectorOffset),
      Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
      int device = 0,
-      VarType variabletype = VarType.VARIABLE) {
-    builder.StartObject(7);
+      VarType variabletype = VarType.VARIABLE,
+      VectorOffset controlDepsOffset = default(VectorOffset),
+      VectorOffset controlDepForOpOffset = default(VectorOffset),
+      VectorOffset controlDepsForVarOffset = default(VectorOffset)) {
+    builder.StartObject(10);
+    FlatVariable.AddControlDepsForVar(builder, controlDepsForVarOffset);
+    FlatVariable.AddControlDepForOp(builder, controlDepForOpOffset);
+    FlatVariable.AddControlDeps(builder, controlDepsOffset);
    FlatVariable.AddDevice(builder, device);
    FlatVariable.AddNdarray(builder, ndarrayOffset);
    FlatVariable.AddShape(builder, shapeOffset);
@@ -57,7 +69,7 @@ public struct FlatVariable : IFlatbufferObject
    return FlatVariable.EndFlatVariable(builder);
  }
 
-  public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
+  public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(10); }
  public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
  public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
  public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
@@ -68,6 +80,18 @@ public struct FlatVariable : IFlatbufferObject
  public static void AddNdarray(FlatBufferBuilder builder, Offset<FlatArray> ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); }
  public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); }
  public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); }
+  public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(7, controlDepsOffset.Value, 0); }
+  public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddControlDepForOp(FlatBufferBuilder builder, VectorOffset controlDepForOpOffset) { builder.AddOffset(8, controlDepForOpOffset.Value, 0); }
+  public static VectorOffset CreateControlDepForOpVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateControlDepForOpVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
+  public static void AddControlDepsForVar(FlatBufferBuilder builder, VectorOffset controlDepsForVarOffset) { builder.AddOffset(9, controlDepsForVarOffset.Value, 0); }
+  public static VectorOffset CreateControlDepsForVarVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
+  public static VectorOffset CreateControlDepsForVarVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
+  public static void StartControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
  public static Offset<FlatVariable> EndFlatVariable(FlatBufferBuilder builder) {
    int o = builder.EndObject();
    return new Offset<FlatVariable>(o);
@@ -28,6 +28,12 @@ public final class FlatVariable extends Table {
   public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
   public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
   public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
+  public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }
 
   public static int createFlatVariable(FlatBufferBuilder builder,
       int idOffset,
@@ -36,8 +42,14 @@ public final class FlatVariable extends Table {
       int shapeOffset,
      int ndarrayOffset,
       int device,
-      byte variabletype) {
-    builder.startObject(7);
+      byte variabletype,
+      int controlDepsOffset,
+      int controlDepForOpOffset,
+      int controlDepsForVarOffset) {
+    builder.startObject(10);
+    FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
+    FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
+    FlatVariable.addControlDeps(builder, controlDepsOffset);
     FlatVariable.addDevice(builder, device);
     FlatVariable.addNdarray(builder, ndarrayOffset);
     FlatVariable.addShape(builder, shapeOffset);
@@ -48,7 +60,7 @@ public final class FlatVariable extends Table {
     return FlatVariable.endFlatVariable(builder);
   }
 
-  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); }
+  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
   public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }
@@ -58,6 +70,15 @@ public final class FlatVariable extends Table {
   public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
   public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
   public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
+  public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
+  public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
  public static int endFlatVariable(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;
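For reference, a minimal sketch of writing the new fields through the regenerated Java API above. The variable names ("x", "y", "z") are illustrative only; `startFlatVariable` now reserves 10 vtable slots, with `controlDeps` at slot 7.

```java
import com.google.flatbuffers.FlatBufferBuilder;
import org.nd4j.graph.FlatVariable;

public class FlatVariableControlDepsSketch {
    public static int serialize(FlatBufferBuilder b) {
        // Illustrative names only: "z" must execute after "x" and "y"
        int[] depNames = new int[]{ b.createString("x"), b.createString("y") };
        int controlDeps = FlatVariable.createControlDepsVector(b, depNames);
        int name = b.createString("z");

        FlatVariable.startFlatVariable(b);            // now startObject(10), was 7
        FlatVariable.addName(b, name);                // slot 1, unchanged
        FlatVariable.addControlDeps(b, controlDeps);  // new slot 7
        return FlatVariable.endFlatVariable(b);       // offset of the finished table
    }
}
```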
@@ -90,7 +90,52 @@ class FlatVariable(object):
             return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
         return 0
 
-def FlatVariableStart(builder): builder.StartObject(7)
+    # FlatVariable
+    def ControlDeps(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatVariable
+    def ControlDepsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # FlatVariable
+    def ControlDepForOp(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatVariable
+    def ControlDepForOpLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # FlatVariable
+    def ControlDepsForVar(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # FlatVariable
+    def ControlDepsForVarLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def FlatVariableStart(builder): builder.StartObject(10)
 def FlatVariableAddId(builder, id): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(id), 0)
 def FlatVariableAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
 def FlatVariableAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0)
@@ -99,4 +144,10 @@ def FlatVariableStartShapeVector(builder, numElems): return builder.StartVector(
 def FlatVariableAddNdarray(builder, ndarray): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(ndarray), 0)
 def FlatVariableAddDevice(builder, device): builder.PrependInt32Slot(5, device, 0)
 def FlatVariableAddVariabletype(builder, variabletype): builder.PrependInt8Slot(6, variabletype, 0)
+def FlatVariableAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
+def FlatVariableStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatVariableAddControlDepForOp(builder, controlDepForOp): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepForOp), 0)
+def FlatVariableStartControlDepForOpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def FlatVariableAddControlDepsForVar(builder, controlDepsForVar): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepsForVar), 0)
+def FlatVariableStartControlDepsForVarVector(builder, numElems): return builder.StartVector(4, numElems, 4)
 def FlatVariableEnd(builder): return builder.EndObject()
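And reading the fields back (Java shown; the Python accessors above are equivalent). A hedged sketch: `getRootAsFlatVariable(ByteBuffer)` is the standard flatbuffers Java entry point and is assumed here, since it is not part of this diff.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.nd4j.graph.FlatVariable;

public class FlatVariableControlDepsReadSketch {
    public static List<String> controlDeps(ByteBuffer buf) {
        FlatVariable fv = FlatVariable.getRootAsFlatVariable(buf);  // assumed standard codegen entry point
        List<String> deps = new ArrayList<>();
        for (int i = 0; i < fv.controlDepsLength(); i++) {
            deps.add(fv.controlDeps(i));   // vtable offset 18, added above
        }
        return deps;
    }
}
```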
@@ -35,7 +35,10 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_OUTPUTNAMES = 34,
     VT_OPNAME = 36,
     VT_OUTPUTTYPES = 38,
-    VT_SCALAR = 40
+    VT_SCALAR = 40,
+    VT_CONTROLDEPS = 42,
+    VT_VARCONTROLDEPS = 44,
+    VT_CONTROLDEPFOR = 46
   };
   int32_t id() const {
     return GetField<int32_t>(VT_ID, 0);
@@ -94,6 +97,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const FlatArray *scalar() const {
     return GetPointer<const FlatArray *>(VT_SCALAR);
   }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_VARCONTROLDEPS);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int32_t>(verifier, VT_ID) &&
@@ -132,6 +144,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
           verifier.VerifyVector(outputTypes()) &&
           VerifyOffset(verifier, VT_SCALAR) &&
           verifier.VerifyTable(scalar()) &&
+           VerifyOffset(verifier, VT_CONTROLDEPS) &&
+           verifier.VerifyVector(controlDeps()) &&
+           verifier.VerifyVectorOfStrings(controlDeps()) &&
+           VerifyOffset(verifier, VT_VARCONTROLDEPS) &&
+           verifier.VerifyVector(varControlDeps()) &&
+           verifier.VerifyVectorOfStrings(varControlDeps()) &&
+           VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
+           verifier.VerifyVector(controlDepFor()) &&
+           verifier.VerifyVectorOfStrings(controlDepFor()) &&
            verifier.EndTable();
   }
 };
@@ -196,6 +217,15 @@ struct FlatNodeBuilder {
   void add_scalar(flatbuffers::Offset<FlatArray> scalar) {
     fbb_.AddOffset(FlatNode::VT_SCALAR, scalar);
   }
+  void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
+    fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps);
+  }
+  void add_varControlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps) {
+    fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps);
+  }
+  void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
+    fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
+  }
   explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -228,9 +258,15 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputNames = 0,
     flatbuffers::Offset<flatbuffers::String> opName = 0,
     flatbuffers::Offset<flatbuffers::Vector<int8_t>> outputTypes = 0,
-    flatbuffers::Offset<FlatArray> scalar = 0) {
+    flatbuffers::Offset<FlatArray> scalar = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
   FlatNodeBuilder builder_(_fbb);
   builder_.add_opNum(opNum);
+  builder_.add_controlDepFor(controlDepFor);
+  builder_.add_varControlDeps(varControlDeps);
+  builder_.add_controlDeps(controlDeps);
   builder_.add_scalar(scalar);
   builder_.add_outputTypes(outputTypes);
   builder_.add_opName(opName);
@@ -272,7 +308,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
     const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputNames = nullptr,
     const char *opName = nullptr,
     const std::vector<int8_t> *outputTypes = nullptr,
-    flatbuffers::Offset<FlatArray> scalar = 0) {
+    flatbuffers::Offset<FlatArray> scalar = 0,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
   return nd4j::graph::CreateFlatNode(
       _fbb,
       id,
@@ -293,7 +332,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
       outputNames ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputNames) : 0,
       opName ? _fbb.CreateString(opName) : 0,
       outputTypes ? _fbb.CreateVector<int8_t>(*outputTypes) : 0,
-      scalar);
+      scalar,
+      controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
+      varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
+      controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
 }
 
 inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
@@ -344,11 +344,65 @@ nd4j.graph.FlatNode.prototype.scalar = function(obj) {
   return offset ? (obj || new nd4j.graph.FlatArray).__init(this.bb.__indirect(this.bb_pos + offset), this.bb) : null;
 };
 
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatNode.prototype.controlDeps = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 42);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatNode.prototype.controlDepsLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 42);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatNode.prototype.varControlDeps = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 44);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatNode.prototype.varControlDepsLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 44);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatNode.prototype.controlDepFor = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 46);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 46);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
 /**
  * @param {flatbuffers.Builder} builder
 */
 nd4j.graph.FlatNode.startFlatNode = function(builder) {
-  builder.startObject(19);
+  builder.startObject(22);
 };
 
 /**
@@ -713,6 +767,93 @@ nd4j.graph.FlatNode.addScalar = function(builder, scalarOffset) {
   builder.addFieldOffset(18, scalarOffset, 0);
 };
 
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} controlDepsOffset
+ */
+nd4j.graph.FlatNode.addControlDeps = function(builder, controlDepsOffset) {
+  builder.addFieldOffset(19, controlDepsOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatNode.createControlDepsVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatNode.startControlDepsVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} varControlDepsOffset
+ */
+nd4j.graph.FlatNode.addVarControlDeps = function(builder, varControlDepsOffset) {
+  builder.addFieldOffset(20, varControlDepsOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatNode.createVarControlDepsVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatNode.startVarControlDepsVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} controlDepForOffset
+ */
+nd4j.graph.FlatNode.addControlDepFor = function(builder, controlDepForOffset) {
+  builder.addFieldOffset(21, controlDepForOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatNode.createControlDepForVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
 /**
  * @param {flatbuffers.Builder} builder
  * @returns {flatbuffers.Offset}
@@ -57,7 +57,10 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_SHAPE = 10,
     VT_NDARRAY = 12,
     VT_DEVICE = 14,
-    VT_VARIABLETYPE = 16
+    VT_VARIABLETYPE = 16,
+    VT_CONTROLDEPS = 18,
+    VT_CONTROLDEPFOROP = 20,
+    VT_CONTROLDEPSFORVAR = 22
   };
   const IntPair *id() const {
     return GetPointer<const IntPair *>(VT_ID);
@@ -80,6 +83,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   VarType variabletype() const {
     return static_cast<VarType>(GetField<int8_t>(VT_VARIABLETYPE, 0));
   }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOROP);
+  }
+  const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar() const {
+    return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPSFORVAR);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_ID) &&
@@ -93,6 +105,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
           verifier.VerifyTable(ndarray()) &&
           VerifyField<int32_t>(verifier, VT_DEVICE) &&
           VerifyField<int8_t>(verifier, VT_VARIABLETYPE) &&
+           VerifyOffset(verifier, VT_CONTROLDEPS) &&
+           verifier.VerifyVector(controlDeps()) &&
+           verifier.VerifyVectorOfStrings(controlDeps()) &&
+           VerifyOffset(verifier, VT_CONTROLDEPFOROP) &&
+           verifier.VerifyVector(controlDepForOp()) &&
+           verifier.VerifyVectorOfStrings(controlDepForOp()) &&
+           VerifyOffset(verifier, VT_CONTROLDEPSFORVAR) &&
+           verifier.VerifyVector(controlDepsForVar()) &&
+           verifier.VerifyVectorOfStrings(controlDepsForVar()) &&
           verifier.EndTable();
   }
 };
@@ -121,6 +142,15 @@ struct FlatVariableBuilder {
   void add_variabletype(VarType variabletype) {
     fbb_.AddElement<int8_t>(FlatVariable::VT_VARIABLETYPE, static_cast<int8_t>(variabletype), 0);
   }
+  void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
+    fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps);
+  }
+  void add_controlDepForOp(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp) {
+    fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp);
+  }
+  void add_controlDepsForVar(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar) {
+    fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar);
+  }
   explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -141,8 +171,14 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
     flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
     flatbuffers::Offset<FlatArray> ndarray = 0,
     int32_t device = 0,
-    VarType variabletype = VarType_VARIABLE) {
+    VarType variabletype = VarType_VARIABLE,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp = 0,
+    flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar = 0) {
   FlatVariableBuilder builder_(_fbb);
+  builder_.add_controlDepsForVar(controlDepsForVar);
+  builder_.add_controlDepForOp(controlDepForOp);
+  builder_.add_controlDeps(controlDeps);
   builder_.add_device(device);
   builder_.add_ndarray(ndarray);
   builder_.add_shape(shape);
@@ -161,7 +197,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
     const std::vector<int64_t> *shape = nullptr,
     flatbuffers::Offset<FlatArray> ndarray = 0,
     int32_t device = 0,
-    VarType variabletype = VarType_VARIABLE) {
+    VarType variabletype = VarType_VARIABLE,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp = nullptr,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar = nullptr) {
   return nd4j::graph::CreateFlatVariable(
       _fbb,
       id,
@@ -170,7 +209,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
       shape ? _fbb.CreateVector<int64_t>(*shape) : 0,
       ndarray,
       device,
-      variabletype);
+      variabletype,
+      controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
+      controlDepForOp ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepForOp) : 0,
+      controlDepsForVar ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepsForVar) : 0);
 }
 
 inline const nd4j::graph::FlatVariable *GetFlatVariable(const void *buf) {
@@ -125,11 +125,65 @@ nd4j.graph.FlatVariable.prototype.variabletype = function() {
   return offset ? /** @type {nd4j.graph.VarType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.VarType.VARIABLE;
 };
 
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatVariable.prototype.controlDeps = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 18);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatVariable.prototype.controlDepsLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 18);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatVariable.prototype.controlDepForOp = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 20);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatVariable.prototype.controlDepForOpLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 20);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
+/**
+ * @param {number} index
+ * @param {flatbuffers.Encoding=} optionalEncoding
+ * @returns {string|Uint8Array}
+ */
+nd4j.graph.FlatVariable.prototype.controlDepsForVar = function(index, optionalEncoding) {
+  var offset = this.bb.__offset(this.bb_pos, 22);
+  return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
+};
+
+/**
+ * @returns {number}
+ */
+nd4j.graph.FlatVariable.prototype.controlDepsForVarLength = function() {
+  var offset = this.bb.__offset(this.bb_pos, 22);
+  return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
+};
+
 /**
  * @param {flatbuffers.Builder} builder
 */
 nd4j.graph.FlatVariable.startFlatVariable = function(builder) {
-  builder.startObject(7);
+  builder.startObject(10);
 };
 
 /**
@@ -209,6 +263,93 @@ nd4j.graph.FlatVariable.addVariabletype = function(builder, variabletype) {
   builder.addFieldInt8(6, variabletype, nd4j.graph.VarType.VARIABLE);
 };
 
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} controlDepsOffset
+ */
+nd4j.graph.FlatVariable.addControlDeps = function(builder, controlDepsOffset) {
+  builder.addFieldOffset(7, controlDepsOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatVariable.createControlDepsVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatVariable.startControlDepsVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} controlDepForOpOffset
+ */
+nd4j.graph.FlatVariable.addControlDepForOp = function(builder, controlDepForOpOffset) {
+  builder.addFieldOffset(8, controlDepForOpOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatVariable.createControlDepForOpVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatVariable.startControlDepForOpVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {flatbuffers.Offset} controlDepsForVarOffset
+ */
+nd4j.graph.FlatVariable.addControlDepsForVar = function(builder, controlDepsForVarOffset) {
+  builder.addFieldOffset(9, controlDepsForVarOffset, 0);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {Array.<flatbuffers.Offset>} data
+ * @returns {flatbuffers.Offset}
+ */
+nd4j.graph.FlatVariable.createControlDepsForVarVector = function(builder, data) {
+  builder.startVector(4, data.length, 4);
+  for (var i = data.length - 1; i >= 0; i--) {
+    builder.addOffset(data[i]);
+  }
+  return builder.endVector();
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {number} numElems
+ */
+nd4j.graph.FlatVariable.startControlDepsForVarVector = function(builder, numElems) {
+  builder.startVector(4, numElems, 4);
+};
+
 /**
  * @param {flatbuffers.Builder} builder
  * @returns {flatbuffers.Offset}
@@ -52,6 +52,12 @@ table FlatNode {
 
     //Scalar value - used for scalar ops. Should be single value only.
     scalar:FlatArray;
+
+    //Control dependencies
+    controlDeps:[string];
+    varControlDeps:[string];
+    controlDepFor:[string];
+
 }
 
 root_type FlatNode;
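The regenerated Java FlatNode is not included in this diff, but codegen from this schema should mirror the C++/JS accessors shown above. A hedged sketch of reading a node's op-level control dependencies, assuming the usual flatc Java output (`controlDeps(int)`/`controlDepsLength()` at vtable offset 42):

```java
import org.nd4j.graph.FlatNode;

public class FlatNodeControlDepsSketch {
    // Assumes flatc generates the same accessors on the Java FlatNode
    // as it does for C++/JS above; not taken directly from this diff.
    public static void printControlDeps(FlatNode node) {
        for (int i = 0; i < node.controlDepsLength(); i++) {
            System.out.println("waits for op: " + node.controlDeps(i));
        }
    }
}
```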
@@ -37,6 +37,10 @@ table FlatVariable {
 
     device:int; // default is -1, which means _auto_
     variabletype:VarType;
+
+    controlDeps:[string];
+    controlDepForOp:[string];
+    controlDepsForVar:[string];
 }
 
 root_type FlatVariable;
@@ -659,7 +659,8 @@ public abstract class DifferentialFunction {
         if(sameDiff == null)
             this.ownName = UUID.randomUUID().toString();
         else {
-            this.ownName = sameDiff.getOpName(opName());
+            String n = sameDiff.getOpName(opName());
+            this.ownName = n;
         }
 
         if(sameDiff != null)
@@ -696,30 +697,11 @@ public abstract class DifferentialFunction {
     }
 
     @JsonIgnore
-    private INDArray getX() {
-        INDArray ret = sameDiff.getArrForVarName(args()[0].getVarName());
-        return ret;
+    public INDArray getInputArgument(int index){
+        //Subclasses should implement this
+        throw new UnsupportedOperationException("Not implemented");
     }
 
-    @JsonIgnore
-    private INDArray getY() {
-        if(args().length > 1) {
-            INDArray ret = sameDiff.getArrForVarName(args()[1].getVarName());
-            return ret;
-        }
-        return null;
-    }
-
-    @JsonIgnore
-    private INDArray getZ() {
-        if(isInPlace())
-            return getX();
-        SDVariable opId = outputVariables()[0];
-        INDArray ret = opId.getArr();
-        return ret;
-    }
-
-
 
 
     /**
@@ -860,4 +842,8 @@ public abstract class DifferentialFunction {
 
     public int getNumOutputs(){return -1;}
 
+    /**
+     * Clear the input and output INDArrays, if any are set
+     */
+    public abstract void clearArrays();
 }
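To illustrate the new contract: `getInputArgument(int)` gets a throwing default in the base class, and `clearArrays()` is abstract, so every concrete op must release its array references for the memory manager. A hedged sketch of a subclass override (the op class and its field are illustrative, not an actual op in the codebase; remaining abstract methods are omitted):

```java
import java.util.ArrayList;
import java.util.List;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

// Illustrative only - other abstract DifferentialFunction methods omitted
public abstract class HypotheticalOp extends DifferentialFunction {
    private final List<INDArray> inputArguments = new ArrayList<>();

    @Override
    public INDArray getInputArgument(int index) {
        return index < inputArguments.size() ? inputArguments.get(index) : null;
    }

    @Override
    public void clearArrays() {
        // Drop the references so the memory manager can safely close/reuse them
        inputArguments.clear();
    }
}
```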
@@ -982,8 +982,8 @@ public class DifferentialFunctionFactory {
         return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable();
     }
 
-    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
-        return new BiasAdd(sameDiff(), input, bias).outputVariable();
+    public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
+        return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
     }
 
     public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
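The extra flag makes the data format explicit instead of assumed. A hedged usage sketch (the factory reference and variables are illustrative):

```java
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.samediff.SDVariable;

public class BiasAddSketch {
    public static SDVariable apply(DifferentialFunctionFactory factory, SDVariable input, SDVariable bias) {
        // nchw == true:  channels-first input, bias applied along dimension 1
        // nchw == false: NHWC input, bias applied along the last dimension
        return factory.biasAdd(input, bias, true);
    }
}
```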
@@ -24,6 +24,7 @@ import lombok.Getter;
 import org.nd4j.autodiff.listeners.Listener;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.IMetric;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -319,6 +320,7 @@ public class History {
      * Gets the training evaluations ran during the last epoch
      */
     public EvaluationRecord finalTrainingEvaluations(){
+        Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty");
         return trainingHistory.get(trainingHistory.size() - 1);
     }
 
@@ -326,6 +328,7 @@ public class History {
      * Gets the validation evaluations ran during the last epoch
      */
     public EvaluationRecord finalValidationEvaluations(){
+        Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty");
         return validationHistory.get(validationHistory.size() - 1);
     }
 
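A hedged sketch of the failure mode this guards against (import paths assumed):

```java
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.History;

public class HistorySketch {
    public static EvaluationRecord lastTrainEval(History history) {
        // Before this change, an empty history hit trainingHistory.get(-1)
        // (an uninformative IndexOutOfBoundsException); checkState now fails
        // fast with a clear IllegalStateException message instead.
        return history.finalTrainingEvaluations();
    }
}
```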
@@ -16,34 +16,23 @@
 
 package org.nd4j.autodiff.samediff;
 
-import java.util.Objects;
 import lombok.*;
 import lombok.extern.slf4j.Slf4j;
-import onnx.Onnx;
 import org.nd4j.autodiff.functions.DifferentialFunction;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.autodiff.samediff.internal.Variable;
 import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.blas.params.MMulTranspose;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.Op;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
-import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.util.ArrayUtil;
 import org.nd4j.weightinit.WeightInitScheme;
-import org.nd4j.weightinit.impl.ZeroInitScheme;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
 
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  *
@@ -167,6 +156,10 @@ public class SDVariable implements Serializable {
         if(sameDiff.arrayAlreadyExistsForVarName(getVarName()))
             return sameDiff.getArrForVarName(getVarName());
 
+        if(variableType == VariableType.ARRAY){
+            throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
+        }
+
         //initialize value if it's actually a scalar constant (zero or 1 typically...)
         if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){
             INDArray arr = weightInitScheme.create(dataType, shape);
@@ -211,8 +204,8 @@ public class SDVariable implements Serializable {
      * created automatically when training is performed.
      */
     public SDVariable getGradient() {
-        Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating" +
-                " point variables have gradients", getVarName(), dataType());
+        Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" +
+                " point variables have gradients", dataType(), getVarName());
         return sameDiff.getGradForVariable(getVarName());
     }
 
@@ -230,7 +223,7 @@ public class SDVariable implements Serializable {
     }
 
         long[] initialShape = sameDiff.getShapeForVarName(getVarName());
-        if(initialShape == null) {
+        if(initialShape == null && variableType != VariableType.ARRAY) {
            val arr = getArr();
            if(arr != null)
                return arr.shape();
@@ -254,7 +247,7 @@ public class SDVariable implements Serializable {
     public DataType dataType() {
         if(this.dataType == null){
             //Try to infer datatype instead of returning null
-            if(getArr() != null){
+            if(variableType != VariableType.ARRAY && getArr() != null){
                 this.dataType = getArr().dataType();
             }
         }
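In practice this means op outputs (ARRAY-type variables) must now be evaluated through the graph rather than fetched off the variable. A hedged sketch; the `output(...)` overload shown is assumed rather than taken from this diff:

```java
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ArrayTypeSketch {
    public static INDArray evalArrayVar(SameDiff sd, SDVariable x, SDVariable y) {
        SDVariable z = x.mul(y);   // op output -> VariableType.ARRAY
        // z.getArr() now throws UnsupportedOperationException for ARRAY variables;
        // evaluate through the graph instead (assumed overload):
        Map<String, INDArray> out = sd.output(Collections.<String, INDArray>emptyMap(), z.getVarName());
        return out.get(z.getVarName());
    }
}
```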
@@ -1518,26 +1511,59 @@ public class SDVariable implements Serializable {
 
     /**
      * Add a control dependency for this variable on the specified variable.<br>
-     * Control depnedencies can be used to enforce the execution order.
+     * Control dependencies can be used to enforce the execution order.
      * For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even
      * if Y wouldn't normally depend on the result/values of X.
      *
     * @param controlDependency Control dependency to add for this variable
     */
     public void addControlDependency(SDVariable controlDependency){
-        String cdN = controlDependency.getVarName();
-        String n = this.getVarName();
-        Variable v = sameDiff.getVariables().get(n);
-        if(v.getControlDeps() == null)
-            v.setControlDeps(new ArrayList<String>());
-        if(!v.getControlDeps().contains(cdN))
-            v.getControlDeps().add(cdN);
+        Variable vThis = sameDiff.getVariables().get(getVarName());
+        Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName());
 
-        Variable v2 = sameDiff.getVariables().get(cdN);
-        if(v2.getControlDepsForVar() == null)
-            v2.setControlDepsForVar(new ArrayList<String>());
-        if(!v2.getControlDepsForVar().contains(n))
-            v2.getControlDepsForVar().add(n);
+        //If possible: add control dependency on ops
+        if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){
+            //Op -> Op case
+            SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());
+            SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp());
+
+            if(oThis.getControlDeps() == null)
+                oThis.setControlDeps(new ArrayList<String>());
+            if(!oThis.getControlDeps().contains(oCD.getName()))
+                oThis.getControlDeps().add(oCD.getName());
+
+            if(oCD.getControlDepFor() == null)
+                oCD.setControlDepFor(new ArrayList<String>());
+            if(!oCD.getControlDepFor().contains(oThis.getName()))
+                oCD.getControlDepFor().add(oThis.getName());
+        } else {
+            if(vThis.getOutputOfOp() != null){
+                //const/ph -> op case
+                SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());
+
+                if(oThis.getVarControlDeps() == null)
+                    oThis.setVarControlDeps(new ArrayList<String>());
+
+                if(!oThis.getVarControlDeps().contains(vCD.getName()))
+                    oThis.getVarControlDeps().add(vCD.getName());
+
+                if(vCD.getControlDepsForOp() == null)
+                    vCD.setControlDepsForOp(new ArrayList<String>());
+                if(!vCD.getControlDepsForOp().contains(oThis.getName()))
+                    vCD.getControlDepsForOp().add(oThis.getName());
+            } else {
+                //const/ph -> const/ph case
+                if(vThis.getControlDeps() == null)
+                    vThis.setControlDeps(new ArrayList<String>());
+                if(!vThis.getControlDeps().contains(vCD.getName()))
+                    vThis.getControlDeps().add(vCD.getName());
+
+                if(vCD.getControlDepsForVar() == null)
+                    vCD.setControlDepsForVar(new ArrayList<String>());
+                if(!vCD.getControlDepsForVar().contains(vThis.getName()))
+                    vCD.getControlDepsForVar().add(vThis.getName());
+            }
+        }
     }
 
     /**
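Usage is unchanged; what differs is where the dependency is recorded. A hedged sketch of the op -> op case (variable setup assumed):

```java
import org.nd4j.autodiff.samediff.SDVariable;

public class ControlDepSketch {
    public static void order(SDVariable a, SDVariable b) {
        SDVariable first = a.add(1.0);   // op output
        SDVariable later = b.mul(2.0);   // op output, no data dependency on 'first'
        // Op -> Op case above: recorded on the SameDiffOps via
        // controlDeps / controlDepFor, not only on the Variables
        later.addControlDependency(first);
    }
}
```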
@@ -16,58 +16,16 @@
 package org.nd4j.autodiff.samediff;

-import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
-
 import com.google.flatbuffers.FlatBufferBuilder;
-import java.io.BufferedInputStream;
-import java.io.BufferedOutputStream;
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.lang.reflect.Method;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.IdentityHashMap;
-import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Queue;
-import java.util.Set;
-import java.util.Stack;
-import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Getter;
-import lombok.NonNull;
-import lombok.Setter;
+import lombok.*;
 import lombok.extern.slf4j.Slf4j;
-import lombok.val;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.lang3.ArrayUtils;
 import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
 import org.nd4j.autodiff.execution.conf.OutputMode;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
-import org.nd4j.autodiff.listeners.At;
-import org.nd4j.autodiff.listeners.Listener;
-import org.nd4j.autodiff.listeners.ListenerResponse;
-import org.nd4j.autodiff.listeners.Loss;
-import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.listeners.*;
 import org.nd4j.autodiff.listeners.impl.HistoryListener;
 import org.nd4j.autodiff.listeners.records.History;
 import org.nd4j.autodiff.listeners.records.LossCurve;
@@ -75,34 +33,14 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
 import org.nd4j.autodiff.samediff.config.EvaluationConfig;
 import org.nd4j.autodiff.samediff.config.FitConfig;
 import org.nd4j.autodiff.samediff.config.OutputConfig;
-import org.nd4j.autodiff.samediff.internal.AbstractSession;
-import org.nd4j.autodiff.samediff.internal.DataTypesSession;
-import org.nd4j.autodiff.samediff.internal.InferenceSession;
-import org.nd4j.autodiff.samediff.internal.SameDiffOp;
-import org.nd4j.autodiff.samediff.internal.Variable;
-import org.nd4j.autodiff.samediff.ops.SDBaseOps;
-import org.nd4j.autodiff.samediff.ops.SDBitwise;
-import org.nd4j.autodiff.samediff.ops.SDCNN;
-import org.nd4j.autodiff.samediff.ops.SDImage;
-import org.nd4j.autodiff.samediff.ops.SDLoss;
-import org.nd4j.autodiff.samediff.ops.SDMath;
-import org.nd4j.autodiff.samediff.ops.SDNN;
-import org.nd4j.autodiff.samediff.ops.SDRNN;
-import org.nd4j.autodiff.samediff.ops.SDRandom;
+import org.nd4j.autodiff.samediff.internal.*;
+import org.nd4j.autodiff.samediff.ops.*;
 import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
 import org.nd4j.base.Preconditions;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.evaluation.classification.ROC;
-import org.nd4j.graph.ExecutionMode;
-import org.nd4j.graph.FlatArray;
-import org.nd4j.graph.FlatConfiguration;
-import org.nd4j.graph.FlatGraph;
-import org.nd4j.graph.FlatNode;
-import org.nd4j.graph.FlatVariable;
-import org.nd4j.graph.IntPair;
-import org.nd4j.graph.OpType;
-import org.nd4j.graph.UpdaterState;
+import org.nd4j.graph.*;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -112,8 +50,6 @@ import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
-import org.nd4j.linalg.api.ops.impl.controlflow.If;
-import org.nd4j.linalg.api.ops.impl.controlflow.While;
 import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@@ -136,7 +72,6 @@ import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.GradientUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.AtomicBoolean;
-import org.nd4j.linalg.primitives.AtomicDouble;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.nd4j.linalg.util.DeviceLocalNDArray;
@@ -152,6 +87,17 @@ import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
 import org.nd4j.weightinit.impl.ZeroInitScheme;
 import org.tensorflow.framework.GraphDef;

+import java.io.*;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
+
 /**
  * SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
  * <p>
@@ -683,7 +629,7 @@ public class SameDiff extends SDBaseOps {
         for (val var : variables()) {
             SDVariable clone = var.clone(this);
             SDVariable newVar = sameDiff.var(clone);
-            if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway
+            if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null) { //ARRAY type = "activations" - are overwritten anyway
                 sameDiff.associateArrayWithVariable(var.getArr(), newVar);
             }

@@ -795,9 +741,9 @@ public class SameDiff extends SDBaseOps {
      * @param function the function to get the inputs for
      * @return the input ids for a given function
      */
-    public String[] getInputsForOp(DifferentialFunction function) {
+    public String[] getInputsForOp(@NonNull DifferentialFunction function) {
         if (!ops.containsKey(function.getOwnName()))
-            throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
+            throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\"");
         List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
         return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
     }
@@ -1102,12 +1048,8 @@ public class SameDiff extends SDBaseOps {
                 constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
                 break;
             case ARRAY:
-                // FIXME: remove this before release
-                val session = sessions.get(Thread.currentThread().getId());
-                val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null);
-                session.getNodeOutputs().put(varId, arr);
-                //throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY");
-                break;
+                throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
+                        " this type of variable is calculated ");
             case PLACEHOLDER:
                 //Validate placeholder shapes:
                 long[] phShape = variable.placeholderShape();
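With this change, ARRAY-type variables (op outputs/activations) are strictly computed values: assigning an array to one now throws immediately instead of poking the value into the current session, as the removed FIXME block used to do. A short sketch of the resulting behaviour (names and shapes illustrative; assumes the usual org.nd4j imports):

    SameDiff sd = SameDiff.create();
    SDVariable ph = sd.placeHolder("in", DataType.FLOAT, 1, 4);
    SDVariable out = ph.mul(2.0);   // "out" is an ARRAY-type variable (an activation)

    try {
        // Rejected after this change: ARRAY values are calculated during execution, not assigned
        sd.associateArrayWithVariable(Nd4j.ones(DataType.FLOAT, 1, 4), out);
    } catch (UnsupportedOperationException e) {
        // expected
    }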
@@ -2152,11 +2094,32 @@
             requiredVars.addAll(l.requiredVariables(this).trainingVariables());
         }

-        ArrayList<Listener> listenersWitHistory = new ArrayList<>(listeners);
+        List<Listener> listenersWitHistory = new ArrayList<>(listeners);
+        for(Listener l : this.listeners){
+            if(!listenersWitHistory.contains(l))
+                listenersWitHistory.add(l);
+        }
         listenersWitHistory.add(history);

-        for (int i = 0; i < numEpochs; i++) {
+        SameDiff gradInstance = getFunction("grad");
+        if(gradInstance == null){
+            createGradFunction();
+            gradInstance = getFunction("grad");
+        }
+        TrainingSession ts = new TrainingSession(gradInstance);
+        gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
+
+        Set<String> paramsToTrain = new LinkedHashSet<>();
+        for(Variable v : variables.values()){
+            if(v.getVariable().getVariableType() == VariableType.VARIABLE){
+                //TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped
+                paramsToTrain.add(v.getName());
+            }
+        }
+
+        Loss lastLoss = null;
+        for (int i = 0; i < numEpochs; i++) {
             if (incrementEpochCount && hasListeners) {
                 at.setEpoch(trainingConfig.getEpochCount());
                 for (Listener l : activeListeners) {
@@ -2200,153 +2163,38 @@
                 Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
                 resolveVariablesWith(placeholders);

-                //Calculate gradients:
-                execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners);
-
-                //Apply updater:
+                //Call TrainingSession to perform training
                 if (!initializedTraining)
                     initializeTraining();

-                Map<Class<?>, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners
-                if (hasListeners) {
-                    regScore = new HashMap<>();
-                }
-
-                int iteration = trainingConfig.getIterationCount();
-                int e = trainingConfig.getEpochCount();
-                for (Variable v : variables.values()) {
-                    //Only update trainable params - float type parameters (variable type vars)
-                    SDVariable sdv = v.getVariable();
-                    if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType())
-                        continue;
-
-                    INDArray param = sdv.getArr();
-                    SDVariable gradVar = sdv.getGradient();
-                    if (gradVar == null) {
-                        //Not all trainable parameters have gradients defined.
-                        //Consider graph: in1->loss1; in2->loss2, where we optimize only loss1.
-                        //No gradient will be present for in2, because in2 doesn't impact loss1 at all
-                        continue;
-                    }
-                    INDArray grad = gradVar.getArr();
-                    //Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
-                    // which should flow through to here
-
-                    //Pre-apply regularization (L1, L2)
-                    List<Regularization> r = trainingConfig.getRegularization();
-                    int iterCount = trainingConfig.getIterationCount();
-                    int epochCount = trainingConfig.getEpochCount();
-                    double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
-                    if (r != null && r.size() > 0) {
-                        for (Regularization reg : r) {
-                            if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
-                                reg.apply(param, grad, lr, iterCount, epochCount);
-                            }
-                        }
-                    }
-
-                    //Apply updater. Note that we need to reshape to [1,length] for updater
-                    INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
-                    Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv);
-                    GradientUpdater u = updaterMap.get(sdv.getVarName());
-                    try {
-                        u.applyUpdater(reshapedView, iteration, e);
-                    } catch (Throwable t) {
-                        throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName()
-                                + "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t);
-                    }
-
-                    //Post-apply regularization (weight decay)
-                    if (r != null && r.size() > 0) {
-                        for (Regularization reg : r) {
-                            if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
-                                reg.apply(param, grad, lr, iterCount, epochCount);
-                                if (hasListeners) {
-                                    double score = reg.score(param, iterCount, epochCount);
-                                    if (!regScore.containsKey(reg.getClass())) {
-                                        regScore.put(reg.getClass(), new AtomicDouble());
-                                    }
-                                    regScore.get(reg.getClass()).addAndGet(score);
-                                }
-                            }
-                        }
-                    }
-
-                    if (hasListeners) {
-                        for (Listener l : activeListeners) {
-                            if (l.isActive(at.operation()))
-                                l.preUpdate(this, at, v, reshapedView);
-                        }
-                    }
-
-                    if (trainingConfig.isMinimize()) {
-                        param.subi(grad);
-                    } else {
-                        param.addi(grad);
-                    }
-                }
-
-                double[] d = new double[lossVariables.size() + regScore.size()];
-                List<String> lossVars;
-                if (regScore.size() > 0) {
-                    lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
-                    lossVars.addAll(lossVariables);
-                    int s = regScore.size();
-                    //Collect regularization losses
-                    for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
-                        lossVars.add(entry.getKey().getSimpleName());
-                        d[s] = entry.getValue().get();
-                    }
-                } else {
-                    lossVars = lossVariables;
-                }
-
-                //Collect the losses...
-                SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
-                int count = 0;
-                for (String s : lossVariables) {
-                    INDArray arr = gradFn.getArrForVarName(s);
-                    double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
-                    d[count++] = l;
-                }
-
-                Loss loss = new Loss(lossVars, d);
-
-                if (lossNames == null) {
-                    lossNames = lossVars;
-                } else {
-                    Preconditions.checkState(lossNames.equals(lossVars),
-                            "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
-                }
+                lastLoss = ts.trainingIteration(
+                        trainingConfig,
+                        placeholders,
+                        paramsToTrain,
+                        updaterMap,
+                        ds,
+                        getLossVariables(),
+                        listenersWitHistory,
+                        at);

                 if (lossSums == null) {
-                    lossSums = d;
+                    lossSums = lastLoss.getLosses().clone();
                 } else {
-                    Preconditions.checkState(lossNames.equals(lossVars),
-                            "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
-
                     for (int j = 0; j < lossSums.length; j++) {
-                        lossSums[j] += d[j];
+                        lossSums[j] += lastLoss.getLosses()[j];
                     }
                 }
                 lossCount++;

-                if (hasListeners) {
-                    for (Listener l : activeListeners) {
-                        l.iterationDone(this, at, ds, loss);
-                    }
-                }
-
                 trainingConfig.incrementIterationCount();
             }

             long epochTime = System.currentTimeMillis() - epochStartTime;

             if (incrementEpochCount) {
+                lossNames = lastLoss.getLossNames();
+
                 for (int j = 0; j < lossSums.length; j++)
                     lossSums[j] /= lossCount;
@@ -2356,14 +2204,13 @@
                     lossCurve = new LossCurve(lossSums, lossNames);
                 }

                 if (incrementEpochCount) {
                     if (hasListeners) {
                         boolean doStop = false;
                         Listener stopped = null;
                         for (Listener l : activeListeners) {
                             ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
                             if (res == ListenerResponse.STOP && (i < numEpochs - 1)) {
@@ -2431,7 +2278,6 @@
                 trainingConfig.incrementEpochCount();
             }

             if (i < numEpochs - 1) {
                 iter.reset();
             }
@@ -2507,7 +2353,9 @@
                 INDArray arr = v.getVariable().getArr();
                 long stateSize = trainingConfig.getUpdater().stateSize(arr.length());
                 INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize);
-                updaterMap.put(v.getName(), trainingConfig.getUpdater().instantiate(view, true));
+                GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false);
+                gu.setStateViewArray(view, arr.shape(), arr.ordering(), true);
+                updaterMap.put(v.getName(), gu);
             }

             initializedTraining = true;
@@ -3862,7 +3710,8 @@
             long thisSize = trainingConfig.getUpdater().stateSize(arr.length());
             if (thisSize > 0) {
                 INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize);
-                GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true);
+                GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false);
+                u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call...
                 updaterMap.put(v.getVarName(), u);
             } else {
                 GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true);
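Both updater hunks make the same change: instead of letting instantiate(...) initialize the state view blindly, the state view is attached explicitly with setStateViewArray(...), so the updater also learns the parameter's shape and ordering. A minimal sketch with a concrete updater config (Adam here is just an example choice, not what the PR mandates):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.GradientUpdater;
    import org.nd4j.linalg.learning.config.Adam;

    public class UpdaterInitDemo {
        public static void main(String[] args) {
            INDArray param = Nd4j.rand(3, 4);                   // parameter being trained
            Adam config = new Adam(1e-3);

            long stateSize = config.stateSize(param.length());  // Adam keeps m and v per parameter
            INDArray view = Nd4j.createUninitialized(param.dataType(), 1, stateSize);

            // Two-step init, as in the diff: create without initializing the view, then
            // attach the state view together with the parameter shape/ordering
            GradientUpdater updater = config.instantiate(view, false);
            updater.setStateViewArray(view, param.shape(), param.ordering(), true);
        }
    }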
@@ -3946,7 +3795,53 @@
         sessions.clear();

         //Recalculate datatypes of outputs, and dynamically update them
-        calculateOutputDataTypes(true);
+        Set<String> allSeenOps = new HashSet<>();
+        Queue<String> queueOps = new LinkedList<>();
+
+        for(String s : dataTypeMap.keySet()){
+            Variable v = variables.get(s);
+            v.getVariable().setDataType(dataTypeMap.get(s));
+            List<String> inToOp = v.getInputsForOp();
+            if(inToOp != null){
+                for(String op : inToOp) {
+                    if (!allSeenOps.contains(op)) {
+                        allSeenOps.add(op);
+                        queueOps.add(op);
+                    }
+                }
+            }
+        }
+
+        while(!queueOps.isEmpty()){
+            String op = queueOps.remove();
+            SameDiffOp o = ops.get(op);
+            List<String> inVars = o.getInputsToOp();
+            List<DataType> inDTypes = new ArrayList<>();
+            if(inVars != null) {
+                for (String s : inVars) {
+                    SDVariable v = variables.get(s).getVariable();
+                    inDTypes.add(v.dataType());
+                }
+            }
+            List<DataType> outDtypes = o.getOp().calculateOutputDataTypes(inDTypes);
+            List<String> outVars = o.getOutputsOfOp();
+            for( int i=0; i<outVars.size(); i++ ){
+                String varName = outVars.get(i);
+                Variable var = variables.get(varName);
+                SDVariable v = var.getVariable();
+                v.setDataType(outDtypes.get(i));

+                //Also update queue
+                if(var.getInputsForOp() != null){
+                    for(String opName : var.getInputsForOp()){
+                        if(!allSeenOps.contains(opName)){
+                            allSeenOps.add(opName);
+                            queueOps.add(opName);
+                        }
+                    }
+                }
+            }
+        }
     }
 }
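The replacement propagates datatype changes through the graph breadth-first: the queue is seeded with ops that consume a changed variable, and every op whose output dtypes are recomputed enqueues its downstream consumers exactly once. The traversal pattern, stripped of the SameDiff types (all names hypothetical):

    import java.util.*;

    public class BfsPropagationDemo {
        // graph: variable name -> ops that consume it; op name -> its output variable names
        static Map<String, List<String>> consumers = new HashMap<>();
        static Map<String, List<String>> opOutputs = new HashMap<>();

        public static void main(String[] args) {
            consumers.put("x", Arrays.asList("op1"));
            opOutputs.put("op1", Arrays.asList("y"));
            consumers.put("y", Arrays.asList("op2"));
            opOutputs.put("op2", Arrays.asList("z"));

            Set<String> seen = new HashSet<>();
            Queue<String> queue = new LinkedList<>();
            // Seed with ops consuming the changed variable "x"
            for (String op : consumers.getOrDefault("x", Collections.emptyList()))
                if (seen.add(op)) queue.add(op);

            while (!queue.isEmpty()) {
                String op = queue.remove();
                System.out.println("recompute output dtypes of " + op);
                // each retyped output may feed further ops - enqueue each at most once
                for (String out : opOutputs.getOrDefault(op, Collections.emptyList()))
                    for (String next : consumers.getOrDefault(out, Collections.emptyList()))
                        if (seen.add(next)) queue.add(next);
            }
        }
    }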
@@ -4097,6 +3992,8 @@
                 break;
             }
         }
+
+        variables.get(varName).getInputsForOp().remove(function.getOwnName());
     }

     /**
@@ -4476,11 +4373,7 @@
         else if (function instanceof BaseOp) {
             SDVariable[] ret = new SDVariable[1];
             SDVariable checkGet = getVariable(baseName);
-            char ordering = 'c';
             SDVariable[] args = function.args();
-            if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye
-                ordering = function.args()[0].getArr().ordering();
-            }
             if (checkGet == null) {
                 //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme
                 org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0);
@@ -4530,45 +4423,6 @@
         return sameDiffFunctionInstances.get(functionName);
     }

-    /**
-     * @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
-     */
-    @Deprecated
-    public While whileStatement(SameDiffConditional sameDiffConditional,
-                                SameDiffFunctionDefinition conditionBody,
-                                SameDiffFunctionDefinition loopBody
-            , SDVariable[] inputVars) {
-        return While.builder()
-                .inputVars(inputVars)
-                .condition(conditionBody)
-                .predicate(sameDiffConditional)
-                .trueBody(loopBody)
-                .parent(this)
-                .blockName("while-" + UUID.randomUUID().toString())
-                .build();
-    }
-
-    /**
-     * @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
-     */
-    @Deprecated
-    public If ifStatement(SameDiffConditional conditional,
-                          SameDiffFunctionDefinition conditionBody,
-                          SameDiffFunctionDefinition trueBody,
-                          SameDiffFunctionDefinition falseBody
-            , SDVariable[] inputVars) {
-        return If.builder()
-                .conditionBody(conditionBody)
-                .falseBody(falseBody)
-                .trueBody(trueBody)
-                .predicate(conditional)
-                .inputVars(inputVars)
-                .parent(this)
-                .blockName("if-" + UUID.randomUUID().toString())
-                .build();
-    }
-
     /**
      * Create a new TensorArray.
      */
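The builder-based While/If statements are removed outright in favour of the lambda-based control flow in SDBaseOps that their deprecation notes pointed to. A sketch of the ifCond replacement — the parameter order follows the deprecation javadoc above, but the single-method lambda shape is an assumption, so treat this as illustrative rather than definitive:

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.scalar(2.0f));
    SDVariable b = sd.var("b", Nd4j.scalar(5.0f));

    // Condition, true branch and false branch are each a lambda over the graph
    SDVariable result = sd.ifCond("out", "cond",
            s -> a.lt(b),     // predicate: a < b
            s -> a.add(b),    // true body
            s -> a.sub(b));   // false body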
@@ -4648,6 +4502,51 @@
         return execSingle(placeholders, outputs.get(0));
     }

+    /**
+     * See {@link #calculateGradients(Map, Collection)}
+     */
+    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull String... variables) {
+        Preconditions.checkArgument(variables.length > 0, "No variables were specified");
+        return calculateGradients(placeholderVals, Arrays.asList(variables));
+    }
+
+    /**
+     * Calculate and return the gradients for the specified variables
+     *
+     * @param placeholderVals Placeholders. May be null
+     * @param variables       Names of the variables that you want the gradient arrays for
+     * @return Gradients as a map, keyed by the variable name
+     */
+    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
+        Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
+        if (getFunction(GRAD_FN_KEY) == null) {
+            createGradFunction();
+        }
+
+        List<String> gradVarNames = new ArrayList<>(variables.size());
+        for (String s : variables) {
+            Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
+            SDVariable v = getVariable(s).getGradient();
+            if (v != null) {
+                //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
+                gradVarNames.add(v.getVarName());
+            }
+        }
+
+        //Key is gradient variable name
+        Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);
+
+        Map<String, INDArray> out = new HashMap<>();
+        for (String s : variables) {
+            if (getVariable(s).getGradient() != null) {
+                String gradVar = getVariable(s).getGradient().getVarName();
+                out.put(s, grads.get(gradVar));
+            }
+        }
+
+        return out;
+    }
+
     /**
      * Create (if required) and then calculate the variable gradients (backward pass) for this graph.<br>
      * After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}<br>
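A minimal usage sketch for the new calculateGradients API (variable names and shapes are illustrative; assumes the usual SameDiff/Nd4j imports):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3);
    SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
    SDVariable mmul = in.mmul(w);
    SDVariable loss = sd.sum("loss", mmul.mul(mmul));
    sd.setLossVariables("loss");

    // The gradient function is created on demand (createGradFunction), and results
    // come back keyed by the *original* variable name, not the gradient variable name
    Map<String, INDArray> grads = sd.calculateGradients(
            Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 1, 3)), "w", "in");
    INDArray dLdw = grads.get("w");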
@@ -4660,6 +4559,7 @@
      *
      * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
      */
+    @Deprecated
     public void execBackwards(Map<String, INDArray> placeholders, Operation op) {
         execBackwards(placeholders, op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
     }
@@ -4669,10 +4569,12 @@
      * <p>
      * Uses {@link Operation#INFERENCE}.
      */
+    @Deprecated
     public void execBackwards(Map<String, INDArray> placeholders) {
         execBackwards(placeholders, Operation.INFERENCE);
     }

+    @Deprecated
     protected void execBackwards(Map<String, INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
         if (getFunction(GRAD_FN_KEY) == null) {
             createGradFunction();
@@ -4709,6 +4611,7 @@
     /**
      * See {@link #execBackwards(Map, List, Operation)}
      */
+    @Deprecated
     public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, Operation op, String... variableGradNamesList) {
         return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
     }
@@ -4718,6 +4621,7 @@
      * <p>
      * Uses {@link Operation#INFERENCE}.
      */
+    @Deprecated
     public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, String... variableGradNamesList) {
         return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList);
     }
@@ -4730,6 +4634,7 @@
      * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
      * @param variableGradNamesList Names of the gradient variables to calculate
      */
+    @Deprecated
     public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation) {
         return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
     }
@@ -4739,10 +4644,12 @@
      * <p>
      * Uses {@link Operation#INFERENCE}.
      */
+    @Deprecated
     public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) {
         return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE);
     }

+    @Deprecated
     protected Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation,
                                                   MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
         if (getFunction(GRAD_FN_KEY) == null) {
@@ -5462,7 +5369,7 @@
                 0,
                 0,
                 -1,
-                0, 0, 0, 0, 0, 0);
+                0, 0, 0, 0, 0, 0, 0, 0, 0);

         return flatNode;
     }
@@ -5538,7 +5445,7 @@
         val idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();
         List<SDVariable> allVars = variables();
         for (SDVariable variable : allVars) {
-            INDArray arr = variable.getArr();
+            INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr();
             log.trace("Exporting variable: [{}]", variable.getVarName());

             //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
@@ -5582,7 +5489,26 @@
                 shape = FlatVariable.createShapeVector(bufferBuilder, shp);
             }

-            int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType);
+            int controlDeps = 0;
+            int controlDepsForOp = 0;
+            int controlDepsForVar = 0;
+            Variable v = variables.get(varName);
+
+            int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder);
+            if(cds != null)
+                controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds);
+
+            int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder);
+            if(cdsForOp != null)
+                controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp);
+
+            int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder);
+            if(cdsForVar != null)
+                controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar);
+
+            int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape,
+                    array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar);
             flatVariables.add(flatVariable);
         }
@@ -5593,43 +5519,6 @@
             flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
         }

-        // we're dumping scopes now
-        for (Map.Entry<String, SameDiff> scope : sameDiffFunctionInstances.entrySet()) {
-            if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) {
-                //Skip the gradient function for export
-                continue;
-            }
-
-            flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder));
-            val currVarList = new ArrayList<SDVariable>(scope.getValue().variables());
-            // converting all ops from node
-            for (val node : scope.getValue().variables()) {
-                INDArray arr = node.getArr();
-                if (arr == null) {
-                    continue;
-                }
-
-                int name = bufferBuilder.createString(node.getVarName());
-                int array = arr.toFlatArray(bufferBuilder);
-                int id = IntPair.createIntPair(bufferBuilder, ++idx, 0);
-
-                val pair = parseVariable(node.getVarName());
-                reverseMap.put(pair.getFirst(), idx);
-
-                log.trace("Adding [{}] as [{}]", pair.getFirst(), idx);
-
-                byte varType = (byte) node.getVariableType().ordinal();
-                int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType);
-                flatVariables.add(flatVariable);
-            }
-
-            //add functions
-            for (SameDiffOp op : scope.getValue().ops.values()) {
-                DifferentialFunction func = op.getOp();
-                flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null));
-            }
-        }
-
         int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets));
         int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
         int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));
@@ -5958,7 +5847,7 @@
             vars.add(fg.variables(i));
         }

-        FlatConfiguration conf = fg.configuration();
+        // FlatConfiguration conf = fg.configuration();

         /* Reconstruct the graph
         We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control
@@ -5995,6 +5884,35 @@
             SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null);
             sd.variables.put(n, Variable.builder().name(n).variable(var).build());
             sd.variableNameToShape.put(n, shape);
+            Variable v2 = sd.variables.get(n);
+
+            //Reconstruct control dependencies
+            if(v.controlDepsLength() > 0){
+                int num = v.controlDepsLength();
+                List<String> l = new ArrayList<>(num);
+                for( int i=0; i<num; i++ ){
+                    l.add(v.controlDeps(i));
+                }
+                v2.setControlDeps(l);
+            }
+
+            if(v.controlDepForOpLength() > 0){
+                int num = v.controlDepForOpLength();
+                List<String> l = new ArrayList<>(num);
+                for( int i=0; i<num; i++ ){
+                    l.add(v.controlDepForOp(i));
+                }
+                v2.setControlDepsForOp(l);
+            }
+
+            if(v.controlDepsForVarLength() > 0){
+                int num = v.controlDepsForVarLength();
+                List<String> l = new ArrayList<>(num);
+                for( int i=0; i<num; i++ ){
+                    l.add(v.controlDepsForVar(i));
+                }
+                v2.setControlDepsForVar(l);
+            }
+
             FlatArray fa = v.ndarray();
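The three reconstruction blocks above (and the three for ops in the next hunk) all follow the same FlatBuffers accessor pattern: a generated length() method plus an indexed getter, copied element by element into a java.util.List. A hypothetical helper showing the shape of that loop:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.IntFunction;

    class FlatVectorUtil {
        // Hypothetical helper, not part of the PR: copies a FlatBuffers string vector
        // exposed as (length, element(i)) into a List.
        static List<String> toList(int length, IntFunction<String> element) {
            List<String> out = new ArrayList<>(length);
            for (int i = 0; i < length; i++)
                out.add(element.apply(i));
            return out;
        }
        // usage, e.g.: List<String> cds = toList(v.controlDepsLength(), v::controlDeps);
    }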
@@ -6063,7 +5981,37 @@
                 }
                 inputNames[i] = varIn.getVarName();
             }
-            sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames));
+            SameDiffOp op = sd.ops.get(df.getOwnName());
+            op.setInputsToOp(Arrays.asList(inputNames));
+
+            //Reconstruct control dependencies
+            if (fn.controlDepsLength() > 0) {
+                int l = fn.controlDepsLength();
+                List<String> list = new ArrayList<>(l);
+                for( int i=0; i<l; i++ ){
+                    list.add(fn.controlDeps(i));
+                }
+                op.setControlDeps(list);
+            }
+
+            if (fn.varControlDepsLength() > 0) {
+                int l = fn.varControlDepsLength();
+                List<String> list = new ArrayList<>(l);
+                for( int i=0; i<l; i++ ){
+                    list.add(fn.varControlDeps(i));
+                }
+                op.setVarControlDeps(list);
+            }
+
+            if (fn.controlDepForLength() > 0) {
+                int l = fn.controlDepForLength();
+                List<String> list = new ArrayList<>(l);
+                for( int i=0; i<l; i++ ){
+                    list.add(fn.controlDepFor(i));
+                }
+                op.setControlDepFor(list);
+            }

             //Record that input variables are input to this op
             for (String inName : inputNames) {
@@ -6072,9 +6020,7 @@
                     v.setInputsForOp(new ArrayList<String>());
                 }
                 if (!v.getInputsForOp().contains(df.getOwnName())) {
-                    v.getInputsForOp(
-
-                    ).add(df.getOwnName());
+                    v.getInputsForOp().add(df.getOwnName());
                 }
             }
@@ -6414,32 +6360,6 @@
         return sb.toString();
     }

-    /**
-     * Calculate data types for the variables in the graph
-     */
-    public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() {
-        return calculateOutputDataTypes(false);
-    }
-
-    /**
-     * Calculate data types for the variables in the graph
-     */
-    public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
-        List<String> allVars = new ArrayList<>(variables.keySet());
-        DataTypesSession session = new DataTypesSession(this, dynamicUpdate);
-        Map<String, org.nd4j.linalg.api.buffer.DataType> phValues = new HashMap<>();
-        for (Variable v : variables.values()) {
-            if (v.getVariable().isPlaceHolder()) {
-                org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType();
-                Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName());
-                phValues.put(v.getName(), dt);
-            }
-        }
-        Map<String, org.nd4j.linalg.api.buffer.DataType> out = session.output(allVars, phValues, null,
-                Collections.<String>emptyList(), Collections.<Listener>emptyList(), At.defaultAt(Operation.INFERENCE));
-        return out;
-    }
-
     /**
      * For internal use only.
      * Creates a new distinct block name from baseName.
@@ -6470,14 +6390,14 @@
      * @return The imported graph
      */
     public static SameDiff importFrozenTF(File graphFile) {
-        return TFGraphMapper.getInstance().importGraph(graphFile);
+        return TFGraphMapper.importGraph(graphFile);
     }

     /**
      * See {@link #importFrozenTF(File)}
      */
     public static SameDiff importFrozenTF(GraphDef graphDef) {
-        return TFGraphMapper.getInstance().importGraph(graphDef);
+        return TFGraphMapper.importGraph(graphDef);
     }

@@ -6487,7 +6407,7 @@
      * Again, the input can be text or binary.
      */
     public static SameDiff importFrozenTF(InputStream graph) {
-        return TFGraphMapper.getInstance().importGraph(graph);
+        return TFGraphMapper.importGraph(graph);
     }
@@ -6511,7 +6431,7 @@
         int start = 1;

         // if we already have a name like "op_2", start from trying "op_3"
-        if (base.contains("_")) {
+        if (base.contains("_") && base.matches(".*_\\d+")) {
             // extract number used to generate base
             Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
             // extract argIndex used to generate base
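The extra base.matches(".*_\d+") guard means only names that actually end in an underscore-plus-digits suffix take the increment path; a name like "my_op" no longer reaches a matcher that has no trailing number to extract. A quick self-contained check of the two cases (the fallback message is an assumption about behaviour outside this hunk):

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    public class NameSuffixDemo {
        public static void main(String[] args) {
            for (String base : new String[]{"op_2", "my_op"}) {
                if (base.contains("_") && base.matches(".*_\\d+")) {
                    Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
                    num.matches();
                    // "op_2" -> stem "op", next candidate "op_3"
                    System.out.println(base + " -> " + num.group(1) + "_" + (Integer.parseInt(num.group(2)) + 1));
                } else {
                    System.out.println(base + " -> no numeric suffix, counting starts fresh"); // assumed fallback
                }
            }
        }
    }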
@@ -0,0 +1,444 @@
+package org.nd4j.autodiff.samediff.internal;
+
+import lombok.Getter;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.function.Predicate;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.util.*;
+
+/**
+ * Object dependency tracker.
+ * <br>
+ * Dependencies are denoted by: X -> Y, which means "Y depends on X"<br>
+ * In this implementation:<br>
+ * - Dependencies may be satisfied, or not satisfied<br>
+ * - The implementation tracks when the dependencies for an object Y are fully satisfied. This occurs when:<br>
+ * 1. No dependencies X->Y exist<br>
+ * 2. All dependencies of the form X->Y have been marked as satisfied, via markSatisfied(x)<br>
+ * - When a dependency is satisfied, any dependents (Ys) are checked to see if all their dependencies are satisfied<br>
+ * - If a dependent has all dependencies satisfied, it is added to the "new all satisfied" queue for processing,
+ * which can be accessed via {@link #hasNewAllSatisfied()}, {@link #getNewAllSatisfied()} and {@link #getNewAllSatisfiedList()}<br>
+ * <br>
+ * Note: Two types of dependencies exist<br>
+ * 1. Standard dependencies - i.e., "Y depends on X"<br>
+ * 2. "Or" dependencies - i.e., "Y depends on (A or B)"<br>
+ * For Or dependencies of the form "(A or B) -> Y", Y will be marked as "all dependencies satisfied" if either A or B is marked as satisfied.
+ *
+ * @param <T> For a dependency X -> Y, Y has type T
+ * @param <D> For a dependency X -> Y, X has type D
+ */
+@Slf4j
+public abstract class AbstractDependencyTracker<T, D> {
+    @Getter
+    private final Map<T, Set<D>> dependencies;                          //Key: the dependent. Value: all things that the key depends on
+    @Getter
+    private final Map<T, Set<Pair<D, D>>> orDependencies;               //Key: the dependent. Value: the set of OR dependencies
+    private final Map<D, Set<T>> reverseDependencies = new HashMap<>(); //Key: the dependee. Value: The set of all dependents that depend on this value
+    private final Map<D, Set<T>> reverseOrDependencies = new HashMap<>();
+    private final Set<D> satisfiedDependencies = new HashSet<>();       //Mark the dependency as satisfied. If not in set: assumed to not be satisfied
+
+    private final Set<T> allSatisfied;                                  //Set of all dependent values (Ys) that have all dependencies satisfied
+    private final Queue<T> allSatisfiedQueue = new LinkedList<>();      //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods
+
+
+    protected AbstractDependencyTracker() {
+        dependencies = (Map<T, Set<D>>) newTMap();
+        orDependencies = (Map<T, Set<Pair<D, D>>>) newTMap();
+        allSatisfied = newTSet();
+    }
+
+    /**
+     * @return A new map where the dependents (i.e., Y in "X -> Y") are the key
+     */
+    protected abstract Map<T, ?> newTMap();
+
+    /**
+     * @return A new set where the dependents (i.e., Y in "X -> Y") are the key
+     */
+    protected abstract Set<T> newTSet();
+
+    /**
+     * @return A String representation of the dependent object
+     */
+    protected abstract String toStringT(T t);
+
+    /**
+     * @return A String representation of the dependee object
+     */
+    protected abstract String toStringD(D d);
+
+    /**
+     * Clear all internal state for the dependency tracker
+     */
+    public void clear() {
+        dependencies.clear();
+        orDependencies.clear();
+        reverseDependencies.clear();
+        reverseOrDependencies.clear();
+        satisfiedDependencies.clear();
+        allSatisfied.clear();
+        allSatisfiedQueue.clear();
+    }
+
+    /**
+     * @return True if no dependencies have been defined
+     */
+    public boolean isEmpty() {
+        return dependencies.isEmpty() && orDependencies.isEmpty() &&
+                allSatisfiedQueue.isEmpty();
+    }
+
+    /**
+     * @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)}
+     */
+    public boolean isSatisfied(@NonNull D x) {
+        return satisfiedDependencies.contains(x);
+    }
+
+    /**
+     * Mark the specified value as satisfied.
+     * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
+     * call, both of these dependencies are considered satisfied.
+     *
+     * @param x         Value to mark
+     * @param satisfied Whether to mark as satisfied (true) or unsatisfied (false)
+     */
+    public void markSatisfied(@NonNull D x, boolean satisfied) {
+        if (satisfied) {
+            boolean alreadySatisfied = satisfiedDependencies.contains(x);
+
+            if (!alreadySatisfied) {
+                satisfiedDependencies.add(x);
+
+                //Check if any Y's exist that have dependencies that are all satisfied, for X -> Y
+                Set<T> s = reverseDependencies.get(x);
+                Set<T> s2 = reverseOrDependencies.get(x);
+
+                Set<T> set;
+                if (s != null && s2 != null) {
+                    set = newTSet();
+                    set.addAll(s);
+                    set.addAll(s2);
+                } else if (s != null) {
+                    set = s;
+                } else if (s2 != null) {
+                    set = s2;
+                } else {
+                    if (log.isTraceEnabled()) {
+                        log.trace("No values depend on: {}", toStringD(x));
+                    }
+                    return;
+                }
+
+                for (T t : set) {
+                    Set<D> required = dependencies.get(t);
+                    Set<Pair<D, D>> requiredOr = orDependencies.get(t);
+                    boolean allSatisfied = true;
+                    if (required != null) {
+                        for (D d : required) {
+                            if (!isSatisfied(d)) {
+                                allSatisfied = false;
+                                break;
+                            }
+                        }
+                    }
+                    if (allSatisfied && requiredOr != null) {
+                        for (Pair<D, D> p : requiredOr) {
+                            if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) {
+                                allSatisfied = false;
+                                break;
+                            }
+                        }
+                    }
+
+                    if (allSatisfied) {
+                        if (!this.allSatisfied.contains(t)) {
+                            this.allSatisfied.add(t);
+                            this.allSatisfiedQueue.add(t);
+                        }
+                    }
+                }
+            }
+
+        } else {
+            satisfiedDependencies.remove(x);
+            if (!allSatisfied.isEmpty()) {
+
+                Set<T> reverse = reverseDependencies.get(x);
+                if (reverse != null) {
+                    for (T y : reverse) {
+                        if (allSatisfied.contains(y)) {
+                            allSatisfied.remove(y);
+                            allSatisfiedQueue.remove(y);
+                        }
+                    }
+                }
+                Set<T> orReverse = reverseOrDependencies.get(x);
+                if (orReverse != null) {
+                    for (T y : orReverse) {
+                        if (allSatisfied.contains(y) && !isAllSatisfied(y)) {
+                            allSatisfied.remove(y);
+                            allSatisfiedQueue.remove(y);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
+     * or {@link #addOrDependency(Object, Object, Object)}
+     *
+     * @param y Dependent to check
+     * @return True if Y depends on any values
+     */
+    public boolean hasDependency(@NonNull T y) {
+        Set<D> s1 = dependencies.get(y);
+        if (s1 != null && !s1.isEmpty())
+            return true;
+
+        Set<Pair<D, D>> s2 = orDependencies.get(y);
+        return s2 != null && !s2.isEmpty();
+    }
+
+    /**
+     * Get all dependencies x, for x -> y, and (x1 or x2) -> y
+     *
+     * @param y Dependent to get dependencies for
+     * @return List of dependencies
+     */
+    public DependencyList<T, D> getDependencies(@NonNull T y) {
+        Set<D> s1 = dependencies.get(y);
+        Set<Pair<D, D>> s2 = orDependencies.get(y);
+
+        List<D> l1 = (s1 == null ? null : new ArrayList<>(s1));
+        List<Pair<D, D>> l2 = (s2 == null ? null : new ArrayList<>(s2));
+
+        return new DependencyList<>(y, l1, l2);
+    }
+
+    /**
+     * Add a dependency: y depends on x, as in x -> y
+     *
+     * @param y The dependent
+     * @param x The dependee that is required for Y
+     */
+    public void addDependency(@NonNull T y, @NonNull D x) {
+        if (!dependencies.containsKey(y))
+            dependencies.put(y, new HashSet<D>());
+
+        if (!reverseDependencies.containsKey(x))
+            reverseDependencies.put(x, newTSet());
+
+        dependencies.get(y).add(x);
+        reverseDependencies.get(x).add(y);
+
+        checkAndUpdateIfAllSatisfied(y);
+    }
+
protected void checkAndUpdateIfAllSatisfied(@NonNull T y) {
|
||||||
|
boolean allSat = isAllSatisfied(y);
|
||||||
|
if (allSat) {
|
||||||
|
//Case where "x is satisfied" happened before x->y added
|
||||||
|
if (!allSatisfied.contains(y)) {
|
||||||
|
allSatisfied.add(y);
|
||||||
|
allSatisfiedQueue.add(y);
|
||||||
|
}
|
||||||
|
} else if (allSatisfied.contains(y)) {
|
||||||
|
if (!allSatisfiedQueue.contains(y)) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies")
|
||||||
|
.append(" were marked satisfied, but is now additional dependencies have been added.\n");
|
||||||
|
DependencyList<T, D> dl = getDependencies(y);
|
||||||
|
if (dl.getDependencies() != null) {
|
||||||
|
sb.append("Dependencies:\n");
|
||||||
|
for (D d : dl.getDependencies()) {
|
||||||
|
sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dl.getOrDependencies() != null) {
|
||||||
|
sb.append("Or dependencies:\n");
|
||||||
|
for (Pair<D, D> p : dl.getOrDependencies()) {
|
||||||
|
sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalStateException(sb.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
//Not satisfied, but is in the queue -> needs to be removed
|
||||||
|
allSatisfied.remove(y);
|
||||||
|
allSatisfiedQueue.remove(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected boolean isAllSatisfied(@NonNull T y) {
|
||||||
|
Set<D> set1 = dependencies.get(y);
|
||||||
|
|
||||||
|
boolean allSatisfied = true;
|
||||||
|
if (set1 != null) {
|
||||||
|
for (D d : set1) {
|
||||||
|
allSatisfied = isSatisfied(d);
|
||||||
|
if (!allSatisfied)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allSatisfied) {
|
||||||
|
Set<Pair<D, D>> set2 = orDependencies.get(y);
|
||||||
|
if (set2 != null) {
|
||||||
|
for (Pair<D, D> p : set2) {
|
||||||
|
allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
|
||||||
|
if (!allSatisfied)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allSatisfied;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a dependency (x -> y)
|
||||||
|
*
|
||||||
|
* @param y The dependent that currently requires X
|
||||||
|
* @param x The dependee that is no longer required for Y
|
||||||
|
*/
|
||||||
|
public void removeDependency(@NonNull T y, @NonNull D x) {
|
||||||
|
if (!dependencies.containsKey(y) && !orDependencies.containsKey(y))
|
||||||
|
return;
|
||||||
|
|
||||||
|
Set<D> s = dependencies.get(y);
|
||||||
|
if (s != null) {
|
||||||
|
s.remove(x);
|
||||||
|
if (s.isEmpty())
|
||||||
|
dependencies.remove(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
Set<T> s2 = reverseDependencies.get(x);
|
||||||
|
if (s2 != null) {
|
||||||
|
s2.remove(y);
|
||||||
|
if (s2.isEmpty())
|
||||||
|
reverseDependencies.remove(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Set<Pair<D, D>> s3 = orDependencies.get(y);
|
||||||
|
if (s3 != null) {
|
||||||
|
boolean removedReverse = false;
|
||||||
|
Iterator<Pair<D, D>> iter = s3.iterator();
|
||||||
|
while (iter.hasNext()) {
|
||||||
|
Pair<D, D> p = iter.next();
|
||||||
|
if (x.equals(p.getFirst()) || x.equals(p.getSecond())) {
|
||||||
|
iter.remove();
|
||||||
|
|
||||||
|
if (!removedReverse) {
|
||||||
|
Set<T> set1 = reverseOrDependencies.get(p.getFirst());
|
||||||
|
Set<T> set2 = reverseOrDependencies.get(p.getSecond());
|
||||||
|
|
||||||
|
set1.remove(y);
|
||||||
|
set2.remove(y);
|
||||||
|
|
||||||
|
if (set1.isEmpty())
|
||||||
|
reverseOrDependencies.remove(p.getFirst());
|
||||||
|
if (set2.isEmpty())
|
||||||
|
reverseOrDependencies.remove(p.getSecond());
|
||||||
|
|
||||||
|
removedReverse = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (s3 != null && s3.isEmpty())
|
||||||
|
orDependencies.remove(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br>
|
||||||
|
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
|
||||||
|
* dependency is considered satisfied
|
||||||
|
*
|
||||||
|
* @param y Dependent
|
||||||
|
* @param x1 Dependee 1
|
||||||
|
* @param x2 Dependee 2
|
||||||
|
*/
|
||||||
|
public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) {
|
||||||
|
if (!orDependencies.containsKey(y))
|
||||||
|
orDependencies.put(y, new HashSet<Pair<D, D>>());
|
||||||
|
|
||||||
|
if (!reverseOrDependencies.containsKey(x1))
|
||||||
|
reverseOrDependencies.put(x1, newTSet());
|
||||||
|
if (!reverseOrDependencies.containsKey(x2))
|
||||||
|
reverseOrDependencies.put(x2, newTSet());
|
||||||
|
|
||||||
|
orDependencies.get(y).add(new Pair<>(x1, x2));
|
||||||
|
reverseOrDependencies.get(x1).add(y);
|
||||||
|
reverseOrDependencies.get(x2).add(y);
|
||||||
|
|
||||||
|
checkAndUpdateIfAllSatisfied(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y)
|
||||||
|
*/
|
||||||
|
public boolean hasNewAllSatisfied() {
|
||||||
|
return !allSatisfiedQueue.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
|
||||||
|
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
|
||||||
|
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value;
|
||||||
|
* the value is considered "processed" at this point.
|
||||||
|
*
|
||||||
|
* @return The next new "all satisfied dependent"
|
||||||
|
*/
|
||||||
|
public T getNewAllSatisfied() {
|
||||||
|
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||||
|
return allSatisfiedQueue.remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return As per {@link #getNewAllSatisfied()} but returns all values
|
||||||
|
*/
|
||||||
|
public List<T> getNewAllSatisfiedList() {
|
||||||
|
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||||
|
List<T> ret = new ArrayList<>(allSatisfiedQueue);
|
||||||
|
allSatisfiedQueue.clear();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* As per {@link #getNewAllSatisfied()} but instead of returning the first dependee, it returns the first that matches
|
||||||
|
* the provided predicate. If no value matches the predicate, null is returned
|
||||||
|
*
|
||||||
|
* @param predicate Predicate gor checking
|
||||||
|
* @return The first value matching the predicate, or null if no values match the predicate
|
||||||
|
*/
|
||||||
|
public T getFirstNewAllSatisfiedMatching(@NonNull Predicate<T> predicate) {
|
||||||
|
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||||
|
|
||||||
|
T t = allSatisfiedQueue.peek();
|
||||||
|
if (predicate.test(t)) {
|
||||||
|
t = allSatisfiedQueue.remove();
|
||||||
|
allSatisfied.remove(t);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allSatisfiedQueue.size() > 1) {
|
||||||
|
Iterator<T> iter = allSatisfiedQueue.iterator();
|
||||||
|
while (iter.hasNext()) {
|
||||||
|
t = iter.next();
|
||||||
|
if (predicate.test(t)) {
|
||||||
|
iter.remove();
|
||||||
|
allSatisfied.remove(t);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null; //None match predicate
|
||||||
|
}
|
||||||
|
}
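
For reference, a minimal usage sketch of the tracker API above, using the map-based DependencyTracker subclass added later in this PR (the op/variable names here are hypothetical):

    DependencyTracker<String, String> tracker = new DependencyTracker<>();

    //"add" cannot be executed until both of its inputs are available: (x -> add) and (y -> add)
    tracker.addDependency("add", "x");
    tracker.addDependency("add", "y");

    tracker.markSatisfied("x", true);
    //hasNewAllSatisfied() is still false here - "add" is also waiting on "y"

    tracker.markSatisfied("y", true);
    while (tracker.hasNewAllSatisfied()) {
        String ready = tracker.getNewAllSatisfied();    //Returns "add" - all of its dependencies are satisfied
        //... execute the ready op, then markSatisfied(...) its outputs ...
    }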
File diff suppressed because it is too large
@@ -1,107 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.autodiff.samediff.internal;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Infer datatypes for all variables.
 * Optionally update the datatypes of variables as we go
 */
public class DataTypesSession extends AbstractSession<DataType, DataTypesSession.DataTypeCalc> {

    protected boolean dynamicUpdate;

    /**
     * @param sameDiff      SameDiff instance
     * @param dynamicUpdate If true: Dynamically update the datatypes as we go
     */
    public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) {
        super(sameDiff);
        this.dynamicUpdate = dynamicUpdate;
    }

    @Override
    public DataType getConstantOrVariable(String variableName) {
        //Variables and constants should always have datatype available
        DataType dt = sameDiff.getVariable(variableName).dataType();
        Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName);
        return dt;
    }

    @Override
    public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
        DifferentialFunction df = sameDiff.getOpById(opName);
        List<DataType> inputDataTypes = new ArrayList<>();
        for (SDVariable v : df.args()) {
            DataType dt = v.dataType();
            if (dt != null) {
                inputDataTypes.add(dt);
            } else {
                String s = v.getVarName();
                for (VarId vid : inputs) {
                    if (vid.getVariable().equals(s)) {
                        DataType dt2 = nodeOutputs.get(vid);
                        Preconditions.checkNotNull(dt2, "No datatype for %s", vid);
                        inputDataTypes.add(dt2);
                    }
                }
            }
        }
        return new DataTypeCalc(df, inputDataTypes);
    }

    @Override
    public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs,
                                 Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
        List<DataType> outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes());

        if (dynamicUpdate) {
            SDVariable[] fnOutputs = op.getFn().outputVariables();
            for (int i = 0; i < fnOutputs.length; i++) {
                SDVariable v = fnOutputs[i];
                DataType d = outTypes.get(i);
                if (v.dataType() != d) {
                    v.setDataType(d);
                }
            }
        }

        return outTypes.toArray(new DataType[outTypes.size()]);
    }

    @AllArgsConstructor
    @Data
    protected static class DataTypeCalc {
        protected final DifferentialFunction fn;
        protected final List<DataType> inputTypes;
    }
}
@@ -0,0 +1,20 @@
package org.nd4j.autodiff.samediff.internal;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.linalg.primitives.Pair;

import java.util.List;

/**
 * A list of dependencies, used in {@link AbstractDependencyTracker}
 *
 * @author Alex Black
 */
@Data
@AllArgsConstructor
public class DependencyList<T, D> {
    private T dependencyFor;
    private List<D> dependencies;
    private List<Pair<D, D>> orDependencies;
}
@@ -0,0 +1,38 @@
package org.nd4j.autodiff.samediff.internal;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.primitives.Pair;

import java.util.*;

/**
 * Dependency tracker. See {@link AbstractDependencyTracker} for details
 *
 * @param <T> For a dependency X -> Y, Y has type T
 * @param <D> For a dependency X -> Y, X has type D
 */
@Slf4j
public class DependencyTracker<T, D> extends AbstractDependencyTracker<T, D> {

    @Override
    protected Map<T, ?> newTMap() {
        return new HashMap<>();
    }

    @Override
    protected Set<T> newTSet() {
        return new HashSet<>();
    }

    @Override
    protected String toStringT(T t) {
        return t.toString();
    }

    @Override
    protected String toStringD(D d) {
        return d.toString();
    }
}
@@ -0,0 +1,44 @@
package org.nd4j.autodiff.samediff.internal;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

import java.util.*;

/**
 * Object dependency tracker, using object identity (not object equality) for the Ys (of type T)<br>
 * See {@link AbstractDependencyTracker} for more details
 *
 * @author Alex Black
 */
@Slf4j
public class IdentityDependencyTracker<T, D> extends AbstractDependencyTracker<T, D> {

    @Override
    protected Map<T, ?> newTMap() {
        return new IdentityHashMap<>();
    }

    @Override
    protected Set<T> newTSet() {
        return Collections.newSetFromMap(new IdentityHashMap<T, Boolean>());
    }

    @Override
    protected String toStringT(T t) {
        if (t instanceof INDArray) {
            INDArray i = (INDArray) t;
            return System.identityHashCode(t) + " - id=" + i.getId() + ", " + i.shapeInfoToString();
        } else {
            return System.identityHashCode(t) + " - " + t.toString();
        }
    }

    @Override
    protected String toStringD(D d) {
        return d.toString();
    }
}
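
A note on why identity (rather than equals) semantics matter for array tracking: two distinct INDArray allocations with identical contents are equal under equals(), but each must be tracked and released independently. A minimal sketch (the dependee name "opX" is hypothetical):

    INDArray a = Nd4j.zeros(DataType.FLOAT, 2, 2);
    INDArray b = Nd4j.zeros(DataType.FLOAT, 2, 2);      //a.equals(b) is true, but a != b

    IdentityDependencyTracker<INDArray, String> tracker = new IdentityDependencyTracker<>();
    tracker.addDependency(a, "opX");
    tracker.addDependency(b, "opX");

    tracker.markSatisfied("opX", true);
    //Both arrays are tracked separately: two entries become "all satisfied", not one
    assert tracker.getNewAllSatisfiedList().size() == 2;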
File diff suppressed because it is too large
@@ -30,8 +30,10 @@ import java.util.List;
 @Builder
 public class SameDiffOp {
     protected String name;
     protected DifferentialFunction op;      //Actual op (note: should be mutable: i.e., cloneable, no arrays set)
     protected List<String> inputsToOp;      //Name of SDVariables as input
     protected List<String> outputsOfOp;     //Name of SDVariables as output
     protected List<String> controlDeps;     //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec)
+    protected List<String> varControlDeps;  //Variables (constants, placeholders, etc) that are control dependencies for this op
+    protected List<String> controlDepFor;   //Name of the variables that this op is a control dependency for
 }
@@ -0,0 +1,60 @@
package org.nd4j.autodiff.samediff.internal;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

import java.io.Closeable;

/**
 * SessionMemMgr - aka "Session Memory Manager" - is responsible for allocating, managing, and deallocating memory used
 * during SameDiff execution.<br>
 * This interface allows different memory management strategies to be used, abstracted away from the actual graph
 * execution logic
 *
 * @author Alex Black
 */
public interface SessionMemMgr extends Closeable {

    /**
     * Allocate an array with the specified datatype and shape.<br>
     * NOTE: This array should be assumed to be uninitialized - i.e., contains random values.
     *
     * @param detached If true: the array is safe to return outside of the SameDiff session run (for example, the array
     *                 is one that may be returned to the user)
     * @param dataType Datatype of the returned array
     * @param shape    Array shape
     * @return The newly allocated (uninitialized) array
     */
    INDArray allocate(boolean detached, DataType dataType, long... shape);

    /**
     * As per {@link #allocate(boolean, DataType, long...)} but from a LongShapeDescriptor instead
     */
    INDArray allocate(boolean detached, LongShapeDescriptor descriptor);

    /**
     * Allocate an uninitialized array with the same datatype and shape as the specified array
     */
    INDArray ulike(INDArray arr);

    /**
     * Duplicate the specified array, to an array that is managed/allocated by the session memory manager
     */
    INDArray dup(INDArray arr);

    /**
     * Release the array. All arrays allocated via one of the allocate methods should be returned here once they are no
     * longer used, and all references to them should be cleared.
     * After calling release, anything could occur to the array - deallocated, workspace closed, reused, etc.
     *
     * @param array The array that can be released
     */
    void release(INDArray array);

    /**
     * Close the session memory manager and clean up any memory / resources, if any
     */
    void close();

}
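
A minimal sketch of the intended allocate/release lifecycle, using the ArrayCloseMemoryMgr implementation added later in this PR (the shapes are arbitrary):

    SessionMemMgr mmgr = new ArrayCloseMemoryMgr();

    //Intermediate op outputs are allocated non-detached, and released once no longer needed
    INDArray tmp = mmgr.allocate(false, DataType.FLOAT, 4, 8);
    //... use tmp as an op output buffer ...
    mmgr.release(tmp);      //For ArrayCloseMemoryMgr, this closes (deallocates) the array immediately

    //Arrays to be returned to the user must be allocated detached, and are not released by the session
    INDArray result = mmgr.allocate(true, DataType.FLOAT, 4, 8);

    mmgr.close();           //Clean up the manager itself at the end of the session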
@@ -0,0 +1,232 @@
package org.nd4j.autodiff.samediff.internal;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicDouble;

import java.util.*;

/**
 * TrainingSession extends InferenceSession, to add training-specific functionality:<br>
 * - Application of regularization (L1, L2, weight decay etc)<br>
 * - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)<br>
 * - Calculation of regularization scores (Score for L1, L2, etc)
 *
 * @author Alex Black
 */
@Slf4j
public class TrainingSession extends InferenceSession {

    protected TrainingConfig config;
    protected Map<String, String> gradVarToVarMap;
    protected Map<String, GradientUpdater> updaters;
    protected Map<String, Integer> lossVarsToLossIdx;
    protected double[] currIterLoss;
    protected Map<Class<?>, AtomicDouble> currIterRegLoss;
    protected List<Listener> listeners;


    public TrainingSession(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Perform one iteration of training - i.e., do forward and backward passes, and update the parameters
     *
     * @param config        Training configuration
     * @param placeholders  Current placeholders
     * @param paramsToTrain Set of parameters that will be trained
     * @param updaters      Current updater state
     * @param batch         Current data/batch (mainly for listeners, should have already been converted to placeholders map)
     * @param lossVariables Loss variables (names)
     * @param listeners     Listeners (if any)
     * @param at            Current epoch, iteration, etc
     * @return The loss at the current iteration
     */
    public Loss trainingIteration(TrainingConfig config, Map<String, INDArray> placeholders, Set<String> paramsToTrain, Map<String, GradientUpdater> updaters,
                                  MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at) {
        this.config = config;
        this.updaters = updaters;

        //Preprocess listeners, get the relevant ones
        if (listeners == null) {
            this.listeners = null;
        } else {
            List<Listener> filtered = new ArrayList<>();
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    filtered.add(l);
                }
            }
            this.listeners = filtered.isEmpty() ? null : filtered;
        }

        List<String> requiredActivations = new ArrayList<>();
        gradVarToVarMap = new HashMap<>();      //Key: gradient variable. Value: variable that the key is gradient for
        for (String s : paramsToTrain) {
            Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
            SDVariable v = sameDiff.getVariable(s);
            Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s",
                    s, v.getVariableType());
            SDVariable grad = sameDiff.getVariable(s).getGradient();
            if (grad == null) {
                //In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it
                //For example: floatVar -> cast to integer -> cast to float -> sum -> loss
                //In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss)
                continue;
            }

            requiredActivations.add(grad.getVarName());

            gradVarToVarMap.put(grad.getVarName(), s);
        }

        //Set up losses
        lossVarsToLossIdx = new LinkedHashMap<>();
        List<String> lossVars;
        currIterLoss = new double[lossVariables.size()];
        currIterRegLoss = new HashMap<>();
        for (int i = 0; i < lossVariables.size(); i++) {
            lossVarsToLossIdx.put(lossVariables.get(i), i);
        }

        //Do training iteration
        List<String> outputVars = new ArrayList<>(gradVarToVarMap.keySet()); //TODO this should be empty, and grads calculated in requiredActivations
        Map<String, INDArray> m = output(outputVars, placeholders, batch, requiredActivations, listeners, at);


        double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()];
        System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length);
        if (currIterRegLoss.size() > 0) {
            lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size());
            lossVars.addAll(lossVariables);
            int s = currIterLoss.length;    //Append regularization losses after the per-loss-variable values
            //Collect regularization losses
            for (Map.Entry<Class<?>, AtomicDouble> entry : currIterRegLoss.entrySet()) {
                lossVars.add(entry.getKey().getSimpleName());
                finalLoss[s] = entry.getValue().get();
                s++;
            }
        } else {
            lossVars = lossVariables;
        }

        Loss loss = new Loss(lossVars, finalLoss);
        if (listeners != null) {
            for (Listener l : listeners) {
                if (l.isActive(Operation.TRAINING)) {
                    l.iterationDone(sameDiff, at, batch, loss);
                }
            }
        }

        return loss;
    }

    @Override
    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                 Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
        //Get outputs from InferenceSession
        INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);

        List<String> outputs = op.getOutputsOfOp();
        int outIdx = 0;
        for (String s : outputs) {
            //If this is a loss variable - record it
            if (lossVarsToLossIdx.containsKey(s)) {
                int lossIdx = lossVarsToLossIdx.get(s);
                INDArray arr = out[outIdx];
                double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
                currIterLoss[lossIdx] += l;
            }

            //If this is a gradient variable - apply the updater and update the parameter array in-line
            if (gradVarToVarMap.containsKey(s)) {
                String varName = gradVarToVarMap.get(s);
                //log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s);

                Variable gradVar = sameDiff.getVariables().get(s);
                if (gradVar.getInputsForOp() != null && !gradVar.getInputsForOp().isEmpty()) {
                    //Should be rare, and we should handle this by tracking dependencies, and only update when safe
                    // (i.e., dependency tracking)
                    throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName);
                }

                GradientUpdater u = updaters.get(varName);
                Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName);

                Variable var = sameDiff.getVariables().get(varName);
                INDArray gradArr = out[outIdx];
                INDArray paramArr = var.getVariable().getArr();

                //Pre-updater regularization (L1, L2)
                List<Regularization> r = config.getRegularization();
                if (r != null && r.size() > 0) {
                    double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
                    for (Regularization reg : r) {
                        if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
                            if (this.listeners != null) {
                                double score = reg.score(paramArr, at.iteration(), at.epoch());
                                if (!currIterRegLoss.containsKey(reg.getClass())) {
                                    currIterRegLoss.put(reg.getClass(), new AtomicDouble());
                                }
                                currIterRegLoss.get(reg.getClass()).addAndGet(score);
                            }
                            reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
                        }
                    }
                }

                u.applyUpdater(gradArr, at.iteration(), at.epoch());

                //Post-apply regularization (weight decay)
                if (r != null && r.size() > 0) {
                    double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
                    for (Regularization reg : r) {
                        if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
                            if (this.listeners != null) {
                                double score = reg.score(paramArr, at.iteration(), at.epoch());
                                if (!currIterRegLoss.containsKey(reg.getClass())) {
                                    currIterRegLoss.put(reg.getClass(), new AtomicDouble());
                                }
                                currIterRegLoss.get(reg.getClass()).addAndGet(score);
                            }
                            reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
                        }
                    }
                }

                if (listeners != null) {
                    for (Listener l : listeners) {
                        if (l.isActive(at.operation()))
                            l.preUpdate(sameDiff, at, var, gradArr);
                    }
                }

                //Update:
                if (config.isMinimize()) {
                    paramArr.subi(gradArr);
                } else {
                    paramArr.addi(gradArr);
                }
                log.trace("Applied updater to gradient and updated variable: {}", varName);
            }

            outIdx++;
        }

        return out;
    }
}
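
To summarize the in-line update logic in getOutputs(...) above, the per-parameter sequence is, in condensed form (a sketch only - the method name is hypothetical, and the real code also accumulates regularization scores for listeners):

    static void applyUpdateInline(INDArray param, INDArray grad, GradientUpdater updater,
                                  List<Regularization> regs, double lr, int iter, int epoch, boolean minimize) {
        //1. BEFORE_UPDATER regularization (e.g., L1/L2) modifies the gradient
        for (Regularization reg : regs)
            if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER)
                reg.apply(param, grad, lr, iter, epoch);

        //2. The updater (Adam, Nesterov, SGD, ...) converts the gradient to an update, in-place
        updater.applyUpdater(grad, iter, epoch);

        //3. POST_UPDATER regularization (e.g., weight decay)
        for (Regularization reg : regs)
            if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER)
                reg.apply(param, grad, lr, iter, epoch);

        //4. In-line parameter update - no separate "updated parameters" graph is materialized
        if (minimize) param.subi(grad); else param.addi(grad);
    }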
@@ -35,8 +35,7 @@ public class Variable {
     protected List<String> controlDepsForOp;    //if an op control dependency (x -> opY) exists, then "opY" will be in this list
     protected List<String> controlDepsForVar;   //if a variable control dependency (x -> varY) exists, then "varY" will be in this list
     protected String outputOfOp;                //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
-    protected List<String> controlDeps;         //Control dependencies: name of variables that must be available before this variable is considered available for execution
-    protected int outputOfOpIdx;                //Index of the output for the op (say, variable is output number 2 of op "outputOfOp")
+    protected List<String> controlDeps;         //Control dependencies: name of ops that must be available before this variable is considered available for execution
     protected SDVariable gradient;              //Variable corresponding to the gradient of this variable
     protected int variableIndex = -1;
 }
@@ -0,0 +1,25 @@
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Abstract memory manager, that implements ulike and dup methods using the underlying allocate methods
 *
 * @author Alex Black
 */
public abstract class AbstractMemoryMgr implements SessionMemMgr {

    @Override
    public INDArray ulike(@NonNull INDArray arr) {
        return allocate(false, arr.dataType(), arr.shape());
    }

    @Override
    public INDArray dup(@NonNull INDArray arr) {
        INDArray out = ulike(arr);
        out.assign(arr);
        return out;
    }
}
@@ -0,0 +1,43 @@
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A simple memory management strategy that deallocates memory as soon as it is no longer needed.<br>
 * This should result in a minimal amount of memory, but will have some overhead - notably, the cost of deallocating
 * and reallocating memory all the time.
 *
 * @author Alex Black
 */
@Slf4j
public class ArrayCloseMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        return Nd4j.createUninitialized(dataType, shape);
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        return Nd4j.create(descriptor, false);
    }

    @Override
    public void release(@NonNull INDArray array) {
        if (!array.wasClosed() && array.closeable()) {
            array.close();
            log.trace("Closed array (deallocated) - id={}", array.getId());
        }
    }

    @Override
    public void close() {
        //No-op
    }
}
@@ -0,0 +1,168 @@
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.primitives.Pair;

import java.util.*;

/**
 * A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:<br>
 * - All arrays that are supposed to be closed, have been closed<br>
 * - Arrays are only passed to the close method exactly once (unless they are requested outputs)<br>
 * - Arrays that are passed to the close method were originally allocated by the session memory manager<br>
 * <br>
 * How to use:<br>
 * 1. Perform an inference or training iteration, as normal<br>
 * 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays<br>
 * <p>
 * NOTE: This is intended for debugging and testing only
 *
 * @author Alex Black
 */
@Slf4j
public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {

    private final SameDiff sd;
    private final SessionMemMgr underlying;
    private final Map<INDArray, Boolean> released = new IdentityHashMap<>();

    public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) {
        this.sd = sd;
        this.underlying = underlying;
    }

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        INDArray out = underlying.allocate(detached, dataType, shape);
        released.put(out, false);
        return out;
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        INDArray out = underlying.allocate(detached, descriptor);
        released.put(out, false);
        return out;
    }

    @Override
    public void release(INDArray array) {
        Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" +
                " this memory manager: id=%s", array.getId());
        if (released.get(array)) {
            //Already released - print the array's dependencies for debugging, before failing below
            InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
            IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
            DependencyList<INDArray, InferenceSession.Dep> dl = arrayUseTracker.getDependencies(array);
            System.out.println(dl);
            if (dl.getDependencies() != null) {
                for (InferenceSession.Dep d : dl.getDependencies()) {
                    System.out.println(d + ": " + arrayUseTracker.isSatisfied(d));
                }
            }
            if (dl.getOrDependencies() != null) {
                for (Pair<InferenceSession.Dep, InferenceSession.Dep> p : dl.getOrDependencies()) {
                    System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond()) + ")");
                }
            }
        }
        Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" +
                " an earlier release call to this memory manager: id=%s", array.getId());
        log.trace("Released array: id = {}", array.getId());
        released.put(array, true);
    }

    @Override
    public void close() {
        underlying.close();
    }

    /**
     * Check that all arrays have been released (after an inference call) except for the specified arrays.
     *
     * @param except Arrays that should not have been closed (usually network outputs)
     */
    public void assertAllReleasedExcept(@NonNull Collection<INDArray> except) {
        Set<INDArray> allVarPhConst = null;

        for (INDArray arr : except) {
            if (!released.containsKey(arr)) {
                //Check if constant, variable or placeholder - maybe user requested that output
                if (allVarPhConst == null)
                    allVarPhConst = identitySetAllConstPhVar();
                if (allVarPhConst.contains(arr))
                    continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager

                throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager");
            }

            boolean released = this.released.get(arr);
            if (released) {
                throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was");
            }
        }

        Set<INDArray> exceptSet = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        exceptSet.addAll(except);

        int numNotClosed = 0;
        Set<INDArray> notReleased = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
        IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
        for (Map.Entry<INDArray, Boolean> e : released.entrySet()) {
            INDArray a = e.getKey();
            if (!exceptSet.contains(a)) {
                boolean b = e.getValue();
                if (!b) {
                    notReleased.add(a);
                    numNotClosed++;
                    log.info("Not released: array id {}", a.getId());
                    DependencyList<INDArray, InferenceSession.Dep> list = arrayUseTracker.getDependencies(a);
                    List<InferenceSession.Dep> l = list.getDependencies();
                    List<Pair<InferenceSession.Dep, InferenceSession.Dep>> l2 = list.getOrDependencies();
                    if (l != null) {
                        for (InferenceSession.Dep d : l) {
                            if (!arrayUseTracker.isSatisfied(d)) {
                                log.info("  Not satisfied: {}", d);
                            }
                        }
                    }
                    if (l2 != null) {
                        for (Pair<InferenceSession.Dep, InferenceSession.Dep> d : l2) {
                            if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) {
                                log.info("  Not satisfied: {}", d);
                            }
                        }
                    }
                }
            }
        }

        if (numNotClosed > 0) {
            System.out.println(sd.summary());
            throw new IllegalStateException(numNotClosed + " arrays were not released but should have been");
        }
    }

    protected Set<INDArray> identitySetAllConstPhVar() {
        Set<INDArray> set = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        for (SDVariable v : sd.variables()) {
            if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) {
                set.add(v.getArr());
            }
        }
        return set;
    }
}
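
A usage sketch for the close-validation wrapper (how the wrapped manager gets wired into the inference session is omitted here, as it depends on session internals; "out" and the placeholders map are hypothetical):

    SameDiff sd = SameDiff.create();
    //... build graph; configure the session to use the wrapped memory manager (wiring omitted) ...
    CloseValidationMemoryMgr mmgr = new CloseValidationMemoryMgr(sd, new ArrayCloseMemoryMgr());

    Map<String, INDArray> outputs = sd.output(placeholders, Arrays.asList("out"));

    //Verify every intermediate array was released exactly once, except the requested outputs
    mmgr.assertAllReleasedExcept(outputs.values());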
@@ -0,0 +1,42 @@
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A simple "no-op" memory manager that relies on the JVM garbage collector for memory management.
 * Assuming other references have been cleared (they should have been), the arrays will be cleaned up by the
 * garbage collector at some point.
 *
 * This memory management strategy is not recommended for performance or memory reasons, and should only be used
 * for testing and debugging purposes
 *
 * @author Alex Black
 */
public class NoOpMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        return Nd4j.createUninitialized(dataType, shape);
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        return Nd4j.create(descriptor, false);
    }

    @Override
    public void release(@NonNull INDArray array) {
        //No-op, rely on GC to clear arrays
    }

    @Override
    public void close() {
        //No-op
    }

}
@@ -90,10 +90,10 @@ public class SDNN extends SDOps {
     }

     /**
-     * @see #biasAdd(String, SDVariable, SDVariable)
+     * @see #biasAdd(String, SDVariable, SDVariable, boolean)
      */
-    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
-        return biasAdd(null, input, bias);
+    public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
+        return biasAdd(null, input, bias, nchw);
     }

     /**
@@ -102,12 +102,14 @@ public class SDNN extends SDOps {
      * @param name  Name of the output variable
      * @param input 4d input variable
      * @param bias  1d bias
+     * @param nchw  The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels].
+     *              Unused for 2d inputs
      * @return Output variable
      */
-    public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
+    public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) {
         validateFloatingPoint("biasAdd", "input", input);
         validateFloatingPoint("biasAdd", "bias", bias);
-        SDVariable ret = f().biasAdd(input, bias);
+        SDVariable ret = f().biasAdd(input, bias, nchw);
         return updateVariableNameAndReference(ret, name);
     }
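
A quick usage sketch for the new signature (hypothetical variables on a hypothetical SameDiff instance sd; nchw=true for a [minibatch, channels, height, width] input):

    SDVariable input = sd.placeHolder("in", DataType.FLOAT, -1, 3, 32, 32);
    SDVariable bias = sd.var("b", Nd4j.zeros(DataType.FLOAT, 3));
    SDVariable out = sd.nn().biasAdd("biasAdd", input, bias, true);     //Bias applied along the channel dimension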
@@ -16,6 +16,7 @@
 package org.nd4j.autodiff.samediff.serde;

+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.shade.guava.primitives.Ints;
 import com.google.flatbuffers.FlatBufferBuilder;
 import java.nio.ByteOrder;

@@ -847,6 +848,28 @@ public class FlatBuffersMapper {
         }
         int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes);

+        //Control dependencies:
+        SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName());
+
+        int opCds = 0;
+        int[] opCdsArr = mapOrNull(sdo.getControlDeps(), bufferBuilder);
+        if (opCdsArr != null) {
+            opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr);
+        }
+
+        int varCds = 0;
+        int[] varCdsArr = mapOrNull(sdo.getVarControlDeps(), bufferBuilder);
+        if (varCdsArr != null) {
+            varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr);
+        }
+
+        int cdsFor = 0;
+        int[] cdsForArr = mapOrNull(sdo.getControlDepFor(), bufferBuilder);
+        if (cdsForArr != null) {
+            cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr);
+        }
+
         int flatNode = FlatNode.createFlatNode(
                 bufferBuilder,
                 ownId,

@@ -867,12 +890,26 @@ public class FlatBuffersMapper {
                 outVarNamesOffset,
                 opNameOffset,
                 outTypesOffset, //Output types
-                scalar
+                scalar,
+                opCds,
+                varCds,
+                cdsFor
         );

         return flatNode;
     }

+    public static int[] mapOrNull(List<String> list, FlatBufferBuilder fbb) {
+        if (list == null)
+            return null;
+        int[] out = new int[list.size()];
+        int i = 0;
+        for (String s : list) {
+            out[i++] = fbb.createString(s);
+        }
+        return out;
+    }
+
     public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df) {
         Map<String, Integer> nameToIdxMap = new HashMap<>();
         int count = 0;
@@ -131,12 +131,12 @@ public class GradCheckUtil {
         // in this case, gradients of x and y are all 0 too

         //Collect variables to get gradients for - we want placeholders AND variables
-        Set<String> gradVarNames = new HashSet<>();
+        Set<String> varsNeedingGrads = new HashSet<>();
         for (Variable v : sd.getVariables().values()) {
             if (v.getVariable().dataType().isFPType() && (v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER)) {
                 SDVariable g = v.getVariable().getGradient();
                 Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
-                gradVarNames.add(g.getVarName());
+                varsNeedingGrads.add(v.getName());
             }
         }

@@ -164,7 +164,7 @@ public class GradCheckUtil {
         }

-        sd.execBackwards(placeholderValues, new ArrayList<>(gradVarNames));
+        Map<String, INDArray> gm = sd.calculateGradients(placeholderValues, varsNeedingGrads);

         //Remove listener, to reduce overhead
         sd.getListeners().remove(listenerIdx);

@@ -183,11 +183,11 @@ public class GradCheckUtil {
             if (g == null) {
                 throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
             }
-            INDArray ga = g.getArr();
+            INDArray ga = gm.get(v.getVarName());
             if (ga == null) {
                 throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
             }
-            if (!Arrays.equals(v.getArr().shape(), g.getArr().shape())) {
+            if (!Arrays.equals(v.getArr().shape(), ga.shape())) {
                 throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" +
                         v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
                         Arrays.toString(ga.shape()));

@@ -408,18 +408,18 @@ public class GradCheckUtil {

         //Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations
         sd.createGradFunction();
-        Set<String> gradVarNames = new HashSet<>();
+        Set<String> varsRequiringGrads = new HashSet<>();
         for (String s : actGrads) {
             SDVariable grad = sd.getVariable(s).gradient();
             Preconditions.checkState(grad != null, "Could not get gradient for activation \"%s\": gradient variable is null", s);
-            gradVarNames.add(grad.getVarName());
+            varsRequiringGrads.add(s);
         }

         //Calculate analytical gradients
-        sd.execBackwards(config.getPlaceholderValues(), new ArrayList<>(gradVarNames));
+        Map<String, INDArray> grads = sd.calculateGradients(config.getPlaceholderValues(), new ArrayList<>(varsRequiringGrads));
         Map<String, INDArray> gradientsForAct = new HashMap<>();
         for (String s : actGrads) {
-            INDArray arr = sd.getVariable(s).gradient().getArr();
+            INDArray arr = grads.get(s);
             Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s);
             gradientsForAct.put(s, arr.dup());
         }
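
The change above replaces execBackwards(...) plus per-variable getArr() calls with the calculateGradients(...) API used in this diff, which returns the gradient arrays directly, keyed by the name of the variable the gradient is for (necessary now that intermediate arrays may be closed after execution). A usage sketch with hypothetical variable names:

    Map<String, INDArray> placeholders = Collections.singletonMap("in", inputArr);
    Map<String, INDArray> grads = sd.calculateGradients(placeholders, Arrays.asList("w", "b"));
    INDArray dLdw = grads.get("w");     //Gradient of the loss with respect to variable "w"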
|
@ -190,11 +190,13 @@ public class OpValidation {
|
||||||
//Check forward pass:
|
//Check forward pass:
|
||||||
if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) {
|
if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) {
|
||||||
SameDiff sd = testCase.sameDiff();
|
SameDiff sd = testCase.sameDiff();
|
||||||
|
|
||||||
|
//Collect variables we need outputs for...
|
||||||
|
Set<String> reqVars = testCase.fwdTestFns().keySet();
|
||||||
|
|
||||||
|
Map<String,INDArray> out;
|
||||||
try {
|
try {
|
||||||
if(testCase.placeholderValues() != null){
|
out = sd.output(testCase.placeholderValues(), new ArrayList<>(reqVars));
|
||||||
sd.resolveVariablesWith(testCase.placeholderValues());
|
|
||||||
}
|
|
||||||
sd.exec(null, sd.outputs());
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
|
throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
|
||||||
}
|
}
|
||||||
|
@@ -206,7 +208,7 @@ public class OpValidation {
                             e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
                 }

-                INDArray actual = v.getArr();
+                INDArray actual = out.get(v.getVarName());
                 if (actual == null) {
                     throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
                 }
@@ -291,6 +293,12 @@ public class OpValidation {
             Preconditions.checkState((orig.getControlDeps() == null) == (des.getControlDeps() == null), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());
             Preconditions.checkState(orig.getControlDeps() == null || orig.getControlDeps().equals(des.getControlDeps()), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());

+            Preconditions.checkState((orig.getVarControlDeps() == null) == (des.getVarControlDeps() == null), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
+            Preconditions.checkState(orig.getVarControlDeps() == null || orig.getVarControlDeps().equals(des.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
+
+            Preconditions.checkState((orig.getControlDepFor() == null) == (des.getControlDepFor() == null), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
+            Preconditions.checkState(orig.getControlDepFor() == null || orig.getControlDepFor().equals(des.getControlDepFor()), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
+
             Preconditions.checkState(orig.getOp().getClass() == des.getOp().getClass(), "Classes differ: %s v. %s", orig.getOp().getClass(), des.getOp().getClass());
         }
@@ -317,6 +325,11 @@ public class OpValidation {
         Map<String,Variable> varsBefore = original.getVariables();
         Map<String,Variable> varsAfter = deserialized.getVariables();
         Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet());

+//        System.out.println(original.summary());
+//        System.out.println("\n\n\n\n");
+//        System.out.println(deserialized.summary());
+
         for(String s : varsBefore.keySet()){
             Variable vB = varsBefore.get(s);
             Variable vA = varsAfter.get(s);
@@ -324,13 +337,15 @@ public class OpValidation {
             Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(),
                     "Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType());

-            Preconditions.checkState((vB.getInputsForOp() == null) == (vA.getInputsForOp() == null), "Input to ops differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp());
-            Preconditions.checkState(vB.getInputsForOp() == null || vB.getInputsForOp().equals(vA.getInputsForOp()), "Inputs differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp());
+            equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp());

-            Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "Output of op differ: %s vs. %s", vB.getOutputOfOp(), vA.getOutputOfOp());
+            Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp());

-            Preconditions.checkState((vB.getControlDeps() == null) == (vA.getControlDeps() == null), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps());
-            Preconditions.checkState(vB.getControlDeps() == null || vB.getControlDeps().equals(vA.getControlDeps()), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps());
+            equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps());
+
+            equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp());
+
+            equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar());
         }

         //Check loss variables:
@@ -343,51 +358,62 @@ public class OpValidation {
                 lossVarBefore, lossVarAfter);
         }

-        //Finally: check execution/output
-        Map<String,INDArray> outOrig = original.outputAll(tc.placeholderValues());
-        Map<String,INDArray> outDe = deserialized.outputAll(tc.placeholderValues());
-        Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model");
-
-        for(String s : outOrig.keySet()){
-            INDArray orig = outOrig.get(s);
-            INDArray deser = outDe.get(s);
-
-            Function<INDArray,String> f = tc.fwdTestFns().get(s);
-            String err = null;
-            if(f != null){
-                err = f.apply(deser);
-            } else {
-                if(!orig.equals(deser)){
-                    //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op)
-                    long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1;
-                    if(orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)){
-                        long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue();
-                        if(count != count2){
-                            err = "INDArray equality failed";
-                        } else {
-                            //TODO is there a better way to do this?
-                            NdIndexIterator iter = new NdIndexIterator(orig.shape());
-                            while(iter.hasNext()){
-                                long[] i = iter.next();
-                                double d1 = orig.getDouble(i);
-                                double d2 = deser.getDouble(i);
-                                if((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5 ){
-                                    err = "INDArray equality failed";
-                                    break;
-                                }
-                            }
-                        }
-                    } else {
-                        err = "INDArray equality failed";
-                    }
-                }
-            }
-
-            Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err);
-        }
+        if(tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) {
+            //Finally: check execution/output
+            Map<String,INDArray> outOrig = original.outputAll(tc.placeholderValues());
+            Map<String,INDArray> outDe = deserialized.outputAll(tc.placeholderValues());
+            Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model");
+
+            for (String s : outOrig.keySet()) {
+                INDArray orig = outOrig.get(s);
+                INDArray deser = outDe.get(s);
+
+                Function<INDArray, String> f = tc.fwdTestFns().get(s);
+                String err = null;
+                if (f != null) {
+                    err = f.apply(deser);
+                } else {
+                    if (!orig.equals(deser)) {
+                        //Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op)
+                        long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1;
+                        if (orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)) {
+                            long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue();
+                            if (count != count2) {
+                                err = "INDArray equality failed";
+                            } else {
+                                //TODO is there a better way to do this?
+                                NdIndexIterator iter = new NdIndexIterator(orig.shape());
+                                while (iter.hasNext()) {
+                                    long[] i = iter.next();
+                                    double d1 = orig.getDouble(i);
+                                    double d2 = deser.getDouble(i);
+                                    if ((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5) {
+                                        err = "INDArray equality failed";
+                                        break;
+                                    }
+                                }
+                            }
+                        } else {
+                            err = "INDArray equality failed";
+                        }
+                    }
+                }

+                Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err);
+            }
+        }
     }

+    protected static void equalConsideringNull(List<String> l1, List<String> l2, String msg, Object... args){
+        //Consider null and length 0 list to be equal (semantically they mean the same thing)
+        boolean empty1 = l1 == null || l1.isEmpty();
+        boolean empty2 = l2 == null || l2.isEmpty();
+        if(empty1 && empty2){
+            return;
+        }
+        Preconditions.checkState(l1 == null || l1.equals(l2), msg, args);
+    }
+
     /**
      * Validate the outputs of a single op
      *
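The new `equalConsideringNull` helper exists because a FlatBuffers round trip does not distinguish "list was null" from "list was empty": an absent vector deserializes to nothing, and code rebuilding the lists may produce either form. A standalone restatement of the intended contract (the committed helper, note, also lets a null `l1` pass against a non-empty `l2`, which looks unintentional; the version below is symmetric):

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class NullVsEmptySketch {
    //null and [] both mean "no entries" - exactly what a serialization round trip
    //can turn a null list into - so treat them as equal before comparing contents
    static boolean equalConsideringNull(List<String> l1, List<String> l2) {
        boolean empty1 = l1 == null || l1.isEmpty();
        boolean empty2 = l2 == null || l2.isEmpty();
        if (empty1 && empty2)
            return true;
        return l1 != null && l1.equals(l2);
    }

    public static void main(String[] args) {
        System.out.println(equalConsideringNull(null, Collections.<String>emptyList())); //true
        System.out.println(equalConsideringNull(Arrays.asList("x"), null));              //false
    }
}
```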
@@ -25,6 +25,7 @@ public class NonInplaceValidationListener extends BaseListener {
     private static AtomicInteger failCounter = new AtomicInteger();

     protected INDArray[] opInputs;
+    protected INDArray[] opInputsOrig;

     public NonInplaceValidationListener(){
         useCounter.getAndIncrement();

@@ -42,14 +43,18 @@ public class NonInplaceValidationListener extends BaseListener {
                 //No input op
                 return;
             } else if(o.y() == null){
+                opInputsOrig = new INDArray[]{o.x()};
                 opInputs = new INDArray[]{o.x().dup()};
             } else {
+                opInputsOrig = new INDArray[]{o.x(), o.y()};
                 opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
             }
         } else if(op.getOp() instanceof DynamicCustomOp){
             INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments();
             opInputs = new INDArray[arr.length];
+            opInputsOrig = new INDArray[arr.length];
             for( int i=0; i<arr.length; i++ ){
+                opInputsOrig[i] = arr[i];
                 opInputs[i] = arr[i].dup();
             }
         } else {

@@ -64,23 +69,6 @@ public class NonInplaceValidationListener extends BaseListener {
             return;
         }

-        INDArray[] inputsAfter;
-        if(op.getOp() instanceof Op){
-            Op o = (Op)op.getOp();
-            if(o.x() == null){
-                //No input op
-                return;
-            } else if(o.y() == null){
-                inputsAfter = new INDArray[]{o.x()};
-            } else {
-                inputsAfter = new INDArray[]{o.x(), o.y()};
-            }
-        } else if(op.getOp() instanceof DynamicCustomOp){
-            inputsAfter = ((DynamicCustomOp) op.getOp()).inputArguments();
-        } else {
-            throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
-        }
-
         MessageDigest md;
         try {
             md = MessageDigest.getInstance("MD5");

@@ -93,12 +81,12 @@ public class NonInplaceValidationListener extends BaseListener {

             //Need to hash - to ensure zero changes to input array
             byte[] before = opInputs[i].data().asBytes();
-            INDArray after = inputsAfter[i];
+            INDArray after = this.opInputsOrig[i];
             boolean dealloc = false;
-            if(opInputs[i].ordering() != inputsAfter[i].ordering() || Arrays.equals(opInputs[i].stride(), inputsAfter[i].stride())
-                    || opInputs[i].elementWiseStride() != inputsAfter[i].elementWiseStride()){
+            if(opInputs[i].ordering() != opInputsOrig[i].ordering() || Arrays.equals(opInputs[i].stride(), opInputsOrig[i].stride())
+                    || opInputs[i].elementWiseStride() != opInputsOrig[i].elementWiseStride()){
                 //Clone if required (otherwise fails for views etc)
-                after = inputsAfter[i].dup();
+                after = opInputsOrig[i].dup();
                 dealloc = true;
             }
             byte[] afterB = after.data().asBytes();
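The NonInplaceValidationListener change is a direct consequence of the memory-management work: re-reading `o.x()`/`inputArguments()` after execution is no longer safe, because the op's input arrays may have been closed or reused by then. So the listener now captures strong references (`opInputsOrig`) alongside the defensive copies (`opInputs`) before the op runs, and compares content hashes afterwards. A self-contained illustration of the before/after-hash technique (the in-place `addi` stands in for a misbehaving op):

```java
import java.security.MessageDigest;
import java.util.Arrays;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class InplaceCheckSketch {
    public static void main(String[] args) throws Exception {
        INDArray input = Nd4j.rand(3, 3);
        INDArray snapshot = input.dup();   //like opInputs[i] = arr[i].dup()

        input.addi(1.0);                   //deliberately modify "the op's input" in place

        //Hash the raw buffer bytes on both sides, as the listener does with MD5
        byte[] before = MessageDigest.getInstance("MD5").digest(snapshot.data().asBytes());
        byte[] after = MessageDigest.getInstance("MD5").digest(input.data().asBytes());
        System.out.println("modified in place: " + !Arrays.equals(before, after)); //true
    }
}
```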
@@ -67,29 +67,41 @@ public final class FlatNode extends Table {
   public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
   public FlatArray scalar() { return scalar(new FlatArray()); }
   public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
+  public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
+  public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatNode(FlatBufferBuilder builder,
       int id,
       int nameOffset,
       byte opType,
       long opNum,
       int propertiesOffset,
       int inputOffset,
       int inputPairedOffset,
       int outputOffset,
       int extraParamsOffset,
       int extraIntegerOffset,
       int extraBoolsOffset,
       int dimensionsOffset,
       int device,
       int scope_id,
       int scope_nameOffset,
       int outputNamesOffset,
       int opNameOffset,
       int outputTypesOffset,
-      int scalarOffset) {
-    builder.startObject(19);
+      int scalarOffset,
+      int controlDepsOffset,
+      int varControlDepsOffset,
+      int controlDepForOffset) {
+    builder.startObject(22);
     FlatNode.addOpNum(builder, opNum);
+    FlatNode.addControlDepFor(builder, controlDepForOffset);
+    FlatNode.addVarControlDeps(builder, varControlDepsOffset);
+    FlatNode.addControlDeps(builder, controlDepsOffset);
     FlatNode.addScalar(builder, scalarOffset);
     FlatNode.addOutputTypes(builder, outputTypesOffset);
     FlatNode.addOpName(builder, opNameOffset);

@@ -111,7 +123,7 @@ public final class FlatNode extends Table {
     return FlatNode.endFlatNode(builder);
   }

-  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); }
+  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
   public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }

@@ -151,6 +163,15 @@ public final class FlatNode extends Table {
   public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
   public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
   public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
+  public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
+  public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
   public static int endFlatNode(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;
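The generated FlatNode accessors follow the standard FlatBuffers vtable layout: field slot i is looked up at byte offset 4 + 2*i of the vtable. That is why the three new vectors, added at slots 19, 20 and 21 (and pushing `startObject` from 19 to 22 fields), read `__offset(42)`, `__offset(44)` and `__offset(46)` above. A quick check of the arithmetic:

```java
class FlatNodeSlotMath {
    //FlatBuffers vtable lookup offset for a generated field slot
    static int vtableOffset(int fieldSlot) {
        return 4 + 2 * fieldSlot;
    }

    public static void main(String[] args) {
        System.out.println(vtableOffset(19)); //42 -> controlDeps
        System.out.println(vtableOffset(20)); //44 -> varControlDeps
        System.out.println(vtableOffset(21)); //46 -> controlDepFor
    }
}
```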
@@ -29,16 +29,28 @@ public final class FlatVariable extends Table {
   public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
   public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
   public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
+  public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatVariable(FlatBufferBuilder builder,
       int idOffset,
       int nameOffset,
       byte dtype,
       int shapeOffset,
       int ndarrayOffset,
       int device,
-      byte variabletype) {
-    builder.startObject(7);
+      byte variabletype,
+      int controlDepsOffset,
+      int controlDepForOpOffset,
+      int controlDepsForVarOffset) {
+    builder.startObject(10);
+    FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
+    FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
+    FlatVariable.addControlDeps(builder, controlDepsOffset);
     FlatVariable.addDevice(builder, device);
     FlatVariable.addNdarray(builder, ndarrayOffset);
     FlatVariable.addShape(builder, shapeOffset);

@@ -49,7 +61,7 @@ public final class FlatVariable extends Table {
     return FlatVariable.endFlatVariable(builder);
   }

-  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); }
+  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
   public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }

@@ -59,6 +71,15 @@ public final class FlatVariable extends Table {
   public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
   public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
   public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
+  public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
+  public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
   public static int endFlatVariable(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;

@@ -67,3 +88,4 @@ public final class FlatVariable extends Table {
   public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
 }
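FlatVariable gains the mirror-image fields at slots 7-9 (offsets 18/20/22, same vtable arithmetic). Writers must create the string vectors before calling `createFlatVariable`, since FlatBuffers builds bottom-up. A hedged sketch of populating the new vectors; the class package, the 0 arguments (absent optional fields) and the dtype/variabletype byte values are illustrative assumptions, not taken from SameDiff's serialization code:

```java
import com.google.flatbuffers.FlatBufferBuilder;
import org.nd4j.graph.FlatVariable;

class FlatVariableBuildSketch {
    static int write(FlatBufferBuilder b) {
        //Child objects (strings, vectors) first - FlatBuffers builds bottom-up
        int name = b.createString("myVar");
        int cds = FlatVariable.createControlDepsVector(b, new int[]{b.createString("opA")});
        int cdOps = FlatVariable.createControlDepForOpVector(b, new int[0]);
        int cdVars = FlatVariable.createControlDepsForVarVector(b, new int[0]);

        //0 = field absent (id/shape/ndarray omitted in this sketch)
        return FlatVariable.createFlatVariable(b, 0, name,
                (byte) 0 /*dtype*/, 0 /*shape*/, 0 /*ndarray*/, 0 /*device*/,
                (byte) 0 /*variabletype*/, cds, cdOps, cdVars);
    }
}
```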
@@ -25,11 +25,7 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
 import org.nd4j.imports.descriptors.onnx.OpDescriptor;
 import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
 import org.nd4j.linalg.api.ops.*;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;

@@ -370,6 +366,8 @@ public class DifferentialFunctionClassHolder {
                 return Merge.class;
             case Switch.OP_NAME:
                 return Switch.class;
+            case LoopCond.OP_NAME:
+                return LoopCond.class;
             case ExternalErrorsFunction.OP_NAME:
                 return ExternalErrorsFunction.class;
             default:
@@ -69,13 +69,9 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan.class,
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual.class,
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.If.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative.class,
             org.nd4j.linalg.api.ops.impl.controlflow.Select.class,
             org.nd4j.linalg.api.ops.impl.controlflow.Where.class,
             org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.While.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.class,
@@ -1,413 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.imports.graphmapper;
-
-import org.nd4j.linalg.util.ArrayUtil;
-import org.nd4j.shade.protobuf.Message;
-import org.nd4j.shade.protobuf.TextFormat;
-import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import org.apache.commons.io.IOUtils;
-import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.autodiff.samediff.VariableType;
-import org.nd4j.autodiff.samediff.internal.SameDiffOp;
-import org.nd4j.autodiff.samediff.internal.Variable;
-import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.Op;
-import org.nd4j.linalg.exception.ND4JIllegalStateException;
-
-import java.io.*;
-import java.util.*;
-
-/**
- * Base implementation for importing a graph
- *
- * @param <GRAPH_TYPE> the type of graph
- * @param <NODE_TYPE> the type of node
- * @param <ATTR_TYPE> the attribute type
- */
-@Slf4j
-public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> implements GraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> {
-
-
-    @Override
-    public Op.Type opTypeForNode(NODE_TYPE nodeDef) {
-        DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef));
-        if (opWithTensorflowName == null)
-            throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef));
-        Op.Type type = opWithTensorflowName.opType();
-        return type;
-
-    }
-
-
-    @Override
-    public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings) {
-        val mappings = propertyMappings.get(getOpType(node));
-        if (mappings == null || mappings.isEmpty()) {
-            return;
-        }
-
-
-        for (val entry : mappings.entrySet()) {
-            mapProperty(entry.getKey(), on, node, graph, sameDiff, propertyMappings);
-        }
-    }
-
-
-    /**
-     * @param inputStream
-     * @return
-     */
-    @Override
-    public SameDiff importGraph(InputStream inputStream) {
-        return importGraph(inputStream, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
-    }
-
-    @Override
-    public SameDiff importGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                                OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
-        GRAPH_TYPE def = readGraph(inputStream, opImportOverrides);
-        return importGraph(def, opImportOverrides, opFilter);
-    }
-
-    protected GRAPH_TYPE readGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides) {
-        byte[] bytes = null;
-        GRAPH_TYPE def = null;
-        try {
-            bytes = IOUtils.toByteArray(inputStream); //Buffers internally
-            def = parseGraphFrom(bytes);
-        } catch (IOException e) {
-            try (BufferedInputStream bis2 = new BufferedInputStream(new ByteArrayInputStream(bytes)); BufferedReader reader = new BufferedReader(new InputStreamReader(bis2))) {
-                Message.Builder builder = getNewGraphBuilder();
-
-                StringBuilder str = new StringBuilder();
-                String line = null;
-                while ((line = reader.readLine()) != null) {
-                    str.append(line);//.append("\n");
-                }
-
-                TextFormat.getParser().merge(str.toString(), builder);
-                def = (GRAPH_TYPE) builder.build();
-            } catch (Exception e2) {
-                e2.printStackTrace();
-            }
-        }
-
-        return def;
-    }
-
-    /**
-     * @param graphFile
-     * @return
-     */
-    @Override
-    public SameDiff importGraph(File graphFile) {
-        return importGraph(graphFile, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
-    }
-
-    @Override
-    public SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                                OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
-        GRAPH_TYPE def = null;
-        try (FileInputStream fis = new FileInputStream(graphFile)) {
-            return importGraph(fis, opImportOverrides, opFilter);
-        } catch (Exception e) {
-            throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e);
-        }
-    }
-
-    @Override
-    public Map<String, NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph) {
-        List<NODE_TYPE> nodes = getNodeList(graph);
-        Map<String, NODE_TYPE> ret = new HashMap<>();
-        for (NODE_TYPE node : nodes) {
-            ret.put(getName(node), node);
-        }
-        return ret;
-    }
-
-    @Override
-    public Map<String, NODE_TYPE> nodesByName(GRAPH_TYPE graph) {
-        val nodeTypes = getNodeList(graph);
-        Map<String, NODE_TYPE> ret = new LinkedHashMap<>();
-        for (int i = 0; i < nodeTypes.size(); i++) {
-            ret.put(getName(nodeTypes.get(i)), nodeTypes.get(i));
-        }
-        return ret;
-    }
-
-    /**
-     * This method converts given TF
-     *
-     * @param tfGraph
-     * @return
-     */
-    @Override
-    public SameDiff importGraph(GRAPH_TYPE tfGraph) {
-        return importGraph(tfGraph, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
-    }
-
-    @Override
-    public SameDiff importGraph(GRAPH_TYPE tfGraph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                                OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
-
-        SameDiff diff = SameDiff.create();
-        ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
-        importState.setSameDiff(diff);
-        importState.setGraph(tfGraph);
-
-        Map<String,TENSOR_TYPE> variablesForGraph = variablesForGraph(tfGraph);
-        importState.setVariables(variablesForGraph);
-
-
-        //Add each of the variables first - before importing ops
-        Map<String, Boolean> stringNodes = new HashMap<>(); //Key: name of string variable. Value: if it's a constant
-        for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
-            if (shouldSkip((NODE_TYPE) entry.getValue())) { //TODO only works for TF
-                //Skip some nodes, for example reduction indices (a lot of ND4J/SameDiff ops use int[] etc, not an INDArray/Variable)
-                continue;
-            }
-
-            //First: check if we're skipping the op entirely. If so: don't create the output variables for it.
-            NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
-            String opType = getOpType(node);
-            String opName = getName(node);
-            if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, importState.getGraph() )){
-                log.info("Skipping variables for op: {} (name: {})", opType, opName);
-                continue;
-            }
-
-            //Similarly, if an OpImportOverride is defined, don't create the variables now, as these might be the wrong type
-            //For example, the OpImportOverride might replace the op with some placeholders
-            // If we simply created them now, we'd create the wrong type (Array not placeholder)
-            if(opImportOverrides != null && opImportOverrides.containsKey(opType)){
-                log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, opName);
-                continue;
-            }
-
-
-            DataType dt = dataTypeForTensor(entry.getValue(), 0);
-            INDArray arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
-            long[] shape = hasShape((NODE_TYPE) entry.getValue()) ? getShape((NODE_TYPE) entry.getValue()) : null; //TODO only works for TF
-
-            //Not all variables have datatypes available on import - we have to infer these at a later point
-            // so we'll leave datatypes as null and infer them once all variables/ops have been imported
-            if(dt == DataType.UNKNOWN)
-                dt = null;
-
-            if (isPlaceHolder(entry.getValue())) {
-                diff.placeHolder(entry.getKey(), dt, shape);
-            } else if (isConstant(entry.getValue())) {
-                Preconditions.checkNotNull(arr, "Array is null for placeholder variable %s", entry.getKey());
-                diff.constant(entry.getKey(), arr);
-            } else {
-                //Could be variable, or could be array type (i.e., output of op/"activations")
-                //TODO work out which!
-
-                SDVariable v;
-                if(shape == null || ArrayUtil.contains(shape, 0)){
-                    //No shape, or 0 in shape -> probably not a variable...
-                    v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
-                } else {
-                    v = diff.var(entry.getKey(), dt, shape);
-                }
-                if (arr != null)
-                    diff.associateArrayWithVariable(arr, v);
-            }
-
-//            NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
-            List<String> controlDependencies = getControlDependencies(node);
-            if (controlDependencies != null) {
-                Variable v = diff.getVariables().get(entry.getKey());
-                v.setControlDeps(controlDependencies);
-            }
-        }
-
-        //Map ops
-        val tfNodesList = getNodeList(tfGraph);
-        for (NODE_TYPE node : tfNodesList) {
-            String opType = getOpType(node);
-            OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> importOverride = null;
-            if(opImportOverrides != null){
-                importOverride = opImportOverrides.get(opType);
-            }
-
-            if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, null)){
-                String opName = getName(node);
-                log.info("Skipping op due to op filter: {}", opType, opName);
-                continue;
-            }
-
-            if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) {
-                mapNodeType(node, importState, importOverride, opFilter);
-            }
-        }
-
-
-        /*
-        At this point, we have a few remaining things to do:
-        1. Make sure all datatypes are set on all variables. TF doesn't have datatype info an all op outputs for some reason, so we have to infer in manually
-        2. Make sure all op output variables have been created
-        3. Make sure all SameDiffOp.outputsOfOp is set
-        4. Make sure all Variable.outputOfOp is set
-        5. Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
-         */
-
-        //Make sure Variable.outputOfOp is set
-        for(Variable v : diff.getVariables().values()){
-            if(v.getVariable().isPlaceHolder() || v.getVariable().isConstant())
-                continue;
-
-            //Expect variable names of output variables to be: opName, opName:1, opName:2, etc
-            String n = v.getName();
-            String opName = n;
-            if(v.getName().matches(".*:\\d+")){
-                //i.e., "something:2"
-                int idx = n.lastIndexOf(':');
-                opName = n.substring(0,idx);
-            }
-
-            if(diff.getOps().containsKey(opName)) {
-                //Variable is the output of an op
-                v.setOutputOfOp(opName);
-
-                //Also double check variable type...
-                if(v.getVariable().getVariableType() != VariableType.ARRAY)
-                    v.getVariable().setVariableType(VariableType.ARRAY);
-            }
-        }
-
-        //Initialize any missing output variables
-        for (SameDiffOp op : diff.getOps().values()) {
-            DifferentialFunction df = op.getOp();
-            initOutputVariables(diff, df);
-        }
-
-        //Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
-        //i.e., if control dependency x -> y exists, then:
-        // (a) x.controlDepsForVar should contain "y"
-        // (b) y.controlDeps should contain "x"
-        //Need to do this before output datatype calculation, as these control dep info is used in sessions
-        for(Map.Entry<String,Variable> e : diff.getVariables().entrySet()){
-            Variable v = e.getValue();
-            if(v.getControlDeps() != null){
-                for(String s : v.getControlDeps()){
-                    Variable v2 = diff.getVariables().get(s);
-                    if(v2.getControlDepsForVar() == null)
-                        v2.setControlDepsForVar(new ArrayList<String>());
-                    if(!v2.getControlDepsForVar().contains(e.getKey())){
-                        //Control dep v2 -> v exists, so put v.name into v2.controlDepsForVar
-                        v2.getControlDepsForVar().add(e.getKey());
-                    }
-                }
-            }
-        }
-
-        //Same thing for op control dependencies...
-        for(Map.Entry<String,SameDiffOp> e : diff.getOps().entrySet()){
-            SameDiffOp op = e.getValue();
-            if(op.getControlDeps() != null){
-                for(String s : op.getControlDeps()){
-                    //Control dependency varS -> op exists
-                    Variable v = diff.getVariables().get(s);
-                    if(v.getControlDepsForOp() == null)
-                        v.setControlDepsForOp(new ArrayList<String>());
-                    if(!v.getControlDepsForOp().contains(e.getKey()))
-                        v.getControlDepsForOp().add(e.getKey());
-                }
-            }
-        }
-
-
-        //Infer variable datatypes to ensure all variables have datatypes...
-        boolean anyUnknown = false;
-        for(SDVariable v : diff.variables()){
-            if(v.dataType() == null)
-                anyUnknown = true;
-        }
-        if(anyUnknown){
-            Map<String,DataType> dataTypes = diff.calculateOutputDataTypes();
-            for(SDVariable v : diff.variables()){
-                if(v.dataType() == null){
-                    v.setDataType(dataTypes.get(v.getVarName()));
-                }
-            }
-        }
-
-        //Validate the graph structure
-        validateGraphStructure(diff);
-
-        return diff;
-    }
-
-    protected void initOutputVariables(SameDiff sd, DifferentialFunction df) {
-        String[] outNames = sd.getOutputsForOp(df);
-        SDVariable[] outVars;
-        if (outNames == null) {
-            outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true);
-            outNames = new String[outVars.length];
-            for (int i = 0; i < outVars.length; i++) {
-                outNames[i] = outVars[i].getVarName();
-            }
-            sd.getOps().get(df.getOwnName()).setOutputsOfOp(Arrays.asList(outNames));
-        }
-
-        for (String s : outNames) {
-            sd.getVariables().get(s).setOutputOfOp(df.getOwnName());
-        }
-    }
-
-
-    @Override
-    public boolean validTensorDataType(TENSOR_TYPE tensorType) {
-        return dataTypeForTensor(tensorType, 0) != DataType.UNKNOWN;
-    }
-
-    public void validateGraphStructure(SameDiff sameDiff) {
-        //First: Check placeholders. When SDVariables are added with null shapes, these can be interpreted as a placeholder
-        // but null shapes might simply mean shape isn't available during import right when the variable is added
-        //Idea here: if a "placeholder" is the output of any function, it's not really a placeholder
-        for (SDVariable v : sameDiff.variables()) {
-            String name = v.getVarName();
-            if (sameDiff.isPlaceHolder(name)) {
-                String varOutputOf = sameDiff.getVariables().get(name).getOutputOfOp();
-            }
-        }
-
-        //Second: check that all op inputs actually exist in the graph
-        for (SameDiffOp op : sameDiff.getOps().values()) {
-            List<String> inputs = op.getInputsToOp();
-            if (inputs == null)
-                continue;
-
-            for (String s : inputs) {
-                if (sameDiff.getVariable(s) == null) {
-                    throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName()
-                            + " has input \"" + s + "\" that does not have a corresponding variable in the graph");
-                }
-            }
-        }
-    }
-
-}
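The generic `BaseGraphMapper` above is deleted outright, along with the `GraphMapper` interface below, in favour of the rewritten import path. One invariant it maintained is worth restating, since any replacement importer has to preserve it: control dependencies are stored in both directions, so if `y.controlDeps` contains `"x"`, then `x.controlDepsForVar` must contain `"y"` (the sessions rely on the reverse index). A standalone restatement with simplified stand-in types:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ControlDepReverseLookupSketch {
    //Given forward edges (variable -> the variables it waits on), build the
    //reverse index that the deleted importGraph populated as controlDepsForVar
    static Map<String, List<String>> reverse(Map<String, List<String>> controlDeps) {
        Map<String, List<String>> controlDepsForVar = new HashMap<>();
        for (Map.Entry<String, List<String>> e : controlDeps.entrySet()) {
            for (String dep : e.getValue()) {
                controlDepsForVar.computeIfAbsent(dep, k -> new ArrayList<>()).add(e.getKey());
            }
        }
        return controlDepsForVar;
    }

    public static void main(String[] args) {
        Map<String, List<String>> deps = new HashMap<>();
        deps.put("y", Arrays.asList("x")); //y has a control dependency on x
        System.out.println(reverse(deps)); //{x=[y]}
    }
}
```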
@@ -1,429 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.imports.graphmapper;
-
-import org.nd4j.shade.protobuf.Message;
-import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.Op;
-
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Map graph proto types to
- *
- * {@link SameDiff} instances
- * @param <GRAPH_TYPE> the proto type for the graph
- * @param <NODE_TYPE> the proto type for the node
- * @param <ATTR_TYPE> the proto type for the attribute
- * @param <TENSOR_TYPE> the proto type for the tensor
- *@author Adam Gibson
- */
-public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE> {
-
-    /**
-     * Import a graph as SameDiff from the given file
-     * @param graphFile Input stream pointing to graph file to import
-     * @return Imported graph
-     */
-    SameDiff importGraph(InputStream graphFile);
-
-    SameDiff importGraph(InputStream graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                         OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
-
-    /**
-     * Import a graph as SameDiff from the given file
-     * @param graphFile Graph file to import
-     * @return Imported graph
-     * @see #importGraph(File, Map)
-     */
-    SameDiff importGraph(File graphFile);
-
-    /**
-     * Import a graph as SameDiff from the given file, with optional op import overrides.<br>
-     * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
-     * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
-     *
-     * @param graphFile Graph file to import
-     * @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the
-     *                          operation to import, value is the object used to import it
-     * @return Imported graph
-     */
-    SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                         OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
-
-    /**
-     * This method converts given graph type (in its native format) to SameDiff
-     * @param graph Graph to import
-     * @return Imported graph
-     */
-    SameDiff importGraph(GRAPH_TYPE graph);
-
-    /**
-     * This method converts given graph type (in its native format) to SameDiff<br>
-     * The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
-     * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
-     * @param graph Graph to import
-     * @return Imported graph
-     */
-    SameDiff importGraph(GRAPH_TYPE graph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
-                         OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
-
-
-    /**
-     * Returns true if this node is a special case
-     * (maybe because of name or other scenarios)
-     * that should override {@link #opsToIgnore()}
-     * in certain circumstances
-     * @param node the node to check
-     * @return true if this node is an exception false otherwise
-     */
-    boolean isOpIgnoreException(NODE_TYPE node);
-
-    /**
-     * Get the nodes sorted by n ame
-     * from a given graph
-     * @param graph the graph to get the nodes for
-     * @return the map of the nodes by name
-     * for a given graph
-     */
-    Map<String,NODE_TYPE> nodesByName(GRAPH_TYPE graph);
-
-    /**
-     * Get the target mapping key (usually based on the node name)
-     * for the given function
-     * @param function the function
-     * @param node the node to derive the target mapping from
-     * @return
-     */
-    String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node);
-
-
-    /**
-     *
-     * @param on
-     * @param node
-     * @param graph
-     * @param sameDiff
-     * @param propertyMappings
-     */
-    void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings);
-
-
-    /**
-     *
-     * @param name
-     * @param on
-     * @param node
-     * @param graph
-     * @param sameDiff
-     * @param propertyMappingsForFunction
-     */
-    void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction);
-
-    /**
-     * Get the node from the graph
-     * @param graph the graph to get the node from
-     * @param name the name of the node to get from the graph
-     * @return
-     */
-    NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name);
-
-    /**
-     * Returns true if the given node is a place holder
-     * @param node the node to check
-     * @return true if the node is a place holder or not
-     */
-    boolean isPlaceHolderNode(TENSOR_TYPE node);
-
-    /**
-     * Get the list of control dependencies for the current node (or null if none exist)
-     *
-     * @param node Node to get the control dependencies (if any) for
-     * @return
-     */
-    List<String> getControlDependencies(NODE_TYPE node);
-
-    /**
-     * Dump a binary proto file representation as a
-     * plain string in to the target text file
-     * @param inputFile
-     * @param outputFile
-     */
-    void dumpBinaryProtoAsText(File inputFile,File outputFile);
-
-
-    /**
-     * Dump a binary proto file representation as a
-     * plain string in to the target text file
-     * @param inputFile
-     * @param outputFile
-     */
-    void dumpBinaryProtoAsText(InputStream inputFile,File outputFile);
-
-
-    /**
-     * Get the mapped op name
-     * for a given op
-     * relative to the type of node being mapped.
-     * The input name should be based on a tensorflow
-     * type or onnx type, not the nd4j name
-     * @param name the tensorflow or onnx name
-     * @return the function based on the values in
-     * {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder}
-     */
-    DifferentialFunction getMappedOp(String name);
-
-
-    /**
-     * Get the variables for the given graph
-     * @param graphType the graph to get the variables for
-     * @return a map of variable name to tensor
-     */
-    Map<String,TENSOR_TYPE> variablesForGraph(GRAPH_TYPE graphType);
-
-    /**
-     *
-     * @param name
-     * @param node
-     * @return
-     */
-    String translateToSameDiffName(String name, NODE_TYPE node);
-
-
-    /**
-     *
-     * @param graph
-     * @return
-     */
-    Map<String,NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph);
-
-    /**
-     * Returns an op type for the given input node
-     * @param nodeType the node to use
-     * @return the optype for the given node
*/
|
|
||||||
Op.Type opTypeForNode(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a graph builder for initial definition and parsing.
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
Message.Builder getNewGraphBuilder();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a graph from an input stream
|
|
||||||
* @param inputStream the input stream to load from
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a graph from an input stream
|
|
||||||
* @param inputStream the input stream to load from
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map a node in to the import state covering the {@link SameDiff} instance
|
|
||||||
* @param tfNode the node to map
|
|
||||||
* @param importState the current import state
|
|
||||||
* @param opFilter Optional filter for skipping operations
|
|
||||||
*/
|
|
||||||
void mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState,
|
|
||||||
OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opImportOverride,
|
|
||||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param tensorType
|
|
||||||
* @param outputNum
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum);
|
|
||||||
|
|
||||||
boolean isStringType(TENSOR_TYPE tensor);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @param key
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
String getAttrValueFromNode(NODE_TYPE nodeType,String key);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param attrType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
long[] getShapeFromAttribute(ATTR_TYPE attrType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns true if the given node is a place holder type
|
|
||||||
* (think a yet to be determined shape)_
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean isPlaceHolder(TENSOR_TYPE nodeType);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns true if the given node is a constant
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean isConstant(TENSOR_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* @param tensorName
|
|
||||||
* @param tensorType
|
|
||||||
* @param graph
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the shape for the given tensor type
|
|
||||||
* @param tensorType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
long[] getShapeFromTensor(TENSOR_TYPE tensorType);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Ops to ignore for mapping
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
Set<String> opsToIgnore();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the input node for the given node
|
|
||||||
* @param node the node
|
|
||||||
* @param index hte index
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
String getInputFromNode(NODE_TYPE node, int index);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the number of inputs for a node.
|
|
||||||
* @param nodeType the node to get the number of inputs for
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
int numInputsFor(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether the data type for the tensor is valid
|
|
||||||
* for creating an {@link INDArray}
|
|
||||||
* @param tensorType the tensor proto to test
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean validTensorDataType(TENSOR_TYPE tensorType);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the shape of the attribute value
|
|
||||||
* @param attr the attribute value
|
|
||||||
* @return the shape of the attribute if any or null
|
|
||||||
*/
|
|
||||||
long[] getShapeFromAttr(ATTR_TYPE attr);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the attribute
|
|
||||||
* map for given node
|
|
||||||
* @param nodeType the node
|
|
||||||
* @return the attribute map for the attribute
|
|
||||||
*/
|
|
||||||
Map<String,ATTR_TYPE> getAttrMap(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the name of the node
|
|
||||||
* @param nodeType the node
|
|
||||||
* to get the name for
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
String getName(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean alreadySeen(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean isVariableNode(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* @param opType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean shouldSkip(NODE_TYPE opType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
boolean hasShape(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
long[] getShape(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param nodeType
|
|
||||||
* @param graph
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph);
|
|
||||||
|
|
||||||
|
|
||||||
String getOpType(NODE_TYPE nodeType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param graphType
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
List<NODE_TYPE> getNodeList(GRAPH_TYPE graphType);
|
|
||||||
}
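For orientation, a minimal sketch of how this interface might be driven by a caller. The mapper instance, file name, and override object below are hypothetical; the override map is keyed by the name of the operation to import, per the javadoc above:

    // Sketch only: "mapper" is any GraphMapper implementation,
    // "frozen_model.pb" and "MyCustomOp" are made-up names.
    File graphFile = new File("frozen_model.pb");

    // Plain import: every op in the graph must already be mapped for import
    SameDiff imported = mapper.importGraph(graphFile);

    // Import with an override for a single unmapped/custom op, and no op filter
    Map<String, OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> overrides = new HashMap<>();
    overrides.put("MyCustomOp", myCustomOpImporter);   // hypothetical override instance
    SameDiff imported2 = mapper.importGraph(graphFile, overrides, null);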
@@ -1,31 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.imports.graphmapper;

import lombok.Data;
import org.nd4j.autodiff.samediff.SameDiff;

import java.util.Map;

@Data
public class ImportState<GRAPH_TYPE,TENSOR_TYPE> {
    private SameDiff sameDiff;
    private GRAPH_TYPE graph;
    private Map<String,TENSOR_TYPE> variables;
}
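As ImportState is a Lombok @Data holder, getters and setters are generated. A sketch of how an import loop might populate it before node mapping begins; the TF proto types shown are just one possible binding of the generics:

    ImportState<GraphDef, NodeDef> state = new ImportState<>();   // illustrative type arguments
    state.setSameDiff(SameDiff.create());
    state.setGraph(graphDef);                                     // the parsed native graph
    state.setVariables(mapper.variablesForGraph(graphDef));       // "mapper" is hypothetical here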
@@ -1,652 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.imports.graphmapper.onnx;

import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.guava.primitives.Floats;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.imports.graphmapper.OpImportFilter;
import org.nd4j.imports.graphmapper.OpImportOverride;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;

/**
 * A mapper for onnx graphs to
 * {@link org.nd4j.autodiff.samediff.SameDiff} instances.
 *
 * @author Adam Gibson
 */
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
    private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();

    public static OnnxGraphMapper getInstance() {
        return INSTANCE;
    }

    @Override
    public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
        try {
            Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
            for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
                bufferedWriter.write(node.toString() + "\n");
            }

            bufferedWriter.flush();
            bufferedWriter.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Init a function's attributes
     * @param mappedTfName the onnx name to pick (sometimes ops have multiple names)
     * @param on the function to map
     * @param attributesForNode the attributes for the node
     * @param node the node to map from
     * @param graph the graph the node is a part of
     */
    public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
        val properties = on.mappingsForFunction();
        val tfProperties = properties.get(mappedTfName);
        val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        val attributeAdapters = on.attributeAdaptersForFunction();
        for(val entry : tfProperties.entrySet()) {
            val tfAttrName = entry.getValue().getTfAttrName();
            val currentField = fields.get(entry.getKey());

            AttributeAdapter adapter = null;
            if(tfAttrName != null) {
                if(currentField == null) {
                    continue;
                }
                if(attributeAdapters != null && !attributeAdapters.isEmpty()) {
                    val mappers = attributeAdapters.get(on.tensorflowName());
                    val adapterFor = mappers.get(entry.getKey());
                    adapter = adapterFor;
                }

                if(attributesForNode.containsKey(tfAttrName)) {
                    val attr = attributesForNode.get(tfAttrName);
                    switch (attr.getType()) {
                        case STRING:
                            val setString = attr.getS().toStringUtf8();
                            if(adapter != null) {
                                adapter.mapAttributeFor(setString,currentField,on);
                            }
                            else
                                on.setValueFor(currentField,setString);
                            break;
                        case INT:
                            val setInt = (int) attr.getI();
                            if(adapter != null) {
                                adapter.mapAttributeFor(setInt,currentField,on);
                            }
                            else
                                on.setValueFor(currentField,setInt);
                            break;
                        case INTS:
                            val setList = attr.getIntsList();
                            if(!setList.isEmpty()) {
                                val intList = Ints.toArray(setList);
                                if(adapter != null) {
                                    adapter.mapAttributeFor(intList,currentField,on);
                                }
                                else
                                    on.setValueFor(currentField,intList);
                            }
                            break;
                        case FLOATS:
                            val floatsList = attr.getFloatsList();
                            if(!floatsList.isEmpty()) {
                                val floats = Floats.toArray(floatsList);
                                if(adapter != null) {
                                    adapter.mapAttributeFor(floats,currentField,on);
                                }
                                else
                                    on.setValueFor(currentField,floats);
                                break;
                            }
                            break;
                        case TENSOR:
                            val tensorToGet = mapTensorProto(attr.getT());
                            if(adapter != null) {
                                adapter.mapAttributeFor(tensorToGet,currentField,on);
                            }
                            else
                                on.setValueFor(currentField,tensorToGet);
                            break;
                    }
                }
            }
        }
    }

    @Override
    public boolean isOpIgnoreException(Onnx.NodeProto node) {
        return false;
    }

    @Override
    public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
        return function.opName();
    }

    @Override
    public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
        val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
        val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        /**
         * Map ints and the like. Need to figure out how attribute mapping should work.
         */

        val propsForFunction = on.propertiesForFunction();

        if(mapping.getTfAttrName() == null) {
            int tfMappingIdx = mapping.getTfInputPosition();
            if(tfMappingIdx < 0)
                tfMappingIdx += node.getInputCount();

            val input = node.getInput(tfMappingIdx);
            val inputNode = getInstance().getNodeWithNameFromGraph(graph,input);
            INDArray arr = sameDiff.getArrForVarName(input);
            val field = fields.get(mapping.getPropertyNames()[0]);
            val type = field.getType();
            if(type.equals(int[].class)) {
                try {
                    field.set(arr.data().asInt(),on);
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
            else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
                try {
                    field.set(arr.getInt(0),on);
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
            else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
                try {
                    field.set(arr.getDouble(0),on);
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }

            /**
             * Figure out whether it's an int array
             * or a double array, or maybe a scalar.
             */
        }
        else {
            val tfMappingAttrName = mapping.getOnnxAttrName();
            val attr = getAttrMap(node).get(tfMappingAttrName);
            val type = attr.getType();
            val field = fields.get(mapping.getPropertyNames()[0]);

            Object valueToSet = null;
            switch(type) {
                case INT:
                    valueToSet = attr.getI();
                    break;
                case FLOAT:
                    valueToSet = attr.getF();
                    break;
                case STRING:
                    valueToSet = attr.getF();
                    break;
            }

            try {
                field.set(valueToSet,on);
            } catch (IllegalAccessException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
        for(int i = 0; i < graph.getNodeCount(); i++) {
            val node = graph.getNode(i);
            if(node.getName().equals(name))
                return node;
        }

        return null;
    }

    @Override
    public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
        return false;
    }

    @Override
    public List<String> getControlDependencies(Onnx.NodeProto node) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
        try {
            Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
            for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
                bufferedWriter.write(node.toString());
            }

            bufferedWriter.flush();
            bufferedWriter.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * @param name the tensorflow or onnx name
     * @return the mapped op, resolved by its ONNX op name
     */
    @Override
    public DifferentialFunction getMappedOp(String name) {
        return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name);
    }

    @Override
    public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
        /**
         * Need to figure out why
         * gpu_0/conv1_1 isn't present in VGG
         */
        Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
        for(int i = 0; i < graphProto.getInputCount(); i++) {
            ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
        }

        for(int i = 0; i < graphProto.getOutputCount(); i++) {
            ret.put(graphProto.getOutput(i).getName(),graphProto.getOutput(i).getType().getTensorType());
        }

        for(int i = 0; i < graphProto.getNodeCount(); i++) {
            val node = graphProto.getNode(i);
            val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
            //add -1 as place holder value representing the shape needs to be filled in
            if(!ret.containsKey(name)) {
                addDummyTensor(name,ret);
            }

            for(int j = 0; j < node.getInputCount(); j++) {
                if(!ret.containsKey(node.getInput(j))) {
                    addDummyTensor(node.getInput(j),ret);
                }
            }

            for(int j = 0; j < node.getOutputCount(); j++) {
                if(!ret.containsKey(node.getOutput(j))) {
                    addDummyTensor(node.getOutput(j),ret);
                }
            }
        }

        return ret;
    }

    @Override
    public String translateToSameDiffName(String name, Onnx.NodeProto node) {
        return null;
    }

    protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
        Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
                newBuilder()
                .setDimValue(-1)
                .build();
        Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
                .setShape(
                        Onnx.TensorShapeProto.newBuilder()
                                .addDim(dim)
                                .addDim(dim).build())
                .build();
        to.put(name,typeProto);
    }

    @Override
    public Message.Builder getNewGraphBuilder() {
        return Onnx.GraphProto.newBuilder();
    }

    @Override
    public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
        return Onnx.ModelProto.parseFrom(inputStream).getGraph();
    }

    @Override
    public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
        return Onnx.ModelProto.parseFrom(inputStream).getGraph();
    }

    @Override
    public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
                            OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
                            OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
        val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
        if(differentialFunction == null) {
            throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
        }

        val diff = importState.getSameDiff();
        val idx = importState.getGraph().getNodeList().indexOf(tfNode);
        val name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx);
        try {
            val newInstance = differentialFunction.getClass().newInstance();
            val args = new SDVariable[tfNode.getInputCount()];

            newInstance.setSameDiff(importState.getSameDiff());

            newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
            importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance);
            //ensure we can track node name to function instance later.
            diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
            //diff.addVarNameForImport(tfNode.getName());
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
        return nd4jTypeFromOnnxType(tensorProto.getElemType());
    }

    @Override
    public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
        return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
    }

    /**
     * Convert an onnx type to the proper nd4j type
     * @param dataType the data type to convert
     * @return the nd4j type for the onnx type
     */
    public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
        switch (dataType) {
            case DOUBLE: return DataType.DOUBLE;
            case FLOAT: return DataType.FLOAT;
            case FLOAT16: return DataType.HALF;
            case INT32:
            case INT64: return DataType.INT;
            default: return DataType.UNKNOWN;
        }
    }

    @Override
    public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
        for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
            if(attributeProto.getName().equals(key)) {
                return attributeProto.getS().toString();
            }
        }

        throw new ND4JIllegalStateException("No key found for " + key);
    }

    @Override
    public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
        return Longs.toArray(attributeProto.getT().getDimsList());
    }

    @Override
    public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
        return false;
    }

    @Override
    public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
        return false;
    }

    @Override
    public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
        DataType type = dataTypeForTensor(tensorProto, 0);
        if(!tensorProto.isInitialized()) {
            throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
        }

        Onnx.TensorProto tensor = null;
        for(int i = 0; i < graph.getInitializerCount(); i++) {
            val initializer = graph.getInitializer(i);
            if(initializer.getName().equals(tensorName)) {
                tensor = initializer;
                break;
            }
        }

        if(tensor == null)
            return null;

        ByteString bytes = tensor.getRawData();
        ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
        ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
        directAlloc.put(byteBuffer);
        directAlloc.rewind();
        long[] shape = getShapeFromTensor(tensorProto);
        DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
        INDArray arr = Nd4j.create(buffer).reshape(shape);
        return arr;
    }

    public INDArray mapTensorProto(Onnx.TensorProto tensor) {
        if(tensor == null)
            return null;

        DataType type = nd4jTypeFromOnnxType(tensor.getDataType());

        ByteString bytes = tensor.getRawData();
        ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
        ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
        directAlloc.put(byteBuffer);
        directAlloc.rewind();
        long[] shape = getShapeFromTensor(tensor);
        DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
        INDArray arr = Nd4j.create(buffer).reshape(shape);
        return arr;
    }

    @Override
    public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
        val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
        int dimCount = tensorProto.getShape().getDimCount();
        if(dimCount >= 2)
            for(int i = 0; i < ret.length; i++) {
                ret[i] = (int) tensorProto.getShape().getDim(i).getDimValue();
            }
        else {
            ret[0] = 1;
            for(int i = 1; i < ret.length; i++) {
                ret[i] = (int) tensorProto.getShape().getDim(i - 1).getDimValue();
            }
        }

        return ret;
    }

    /**
     * Get the shape from a tensor proto.
     * Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
     * @param tensorProto the tensor to get the shape from
     * @return the shape of the tensor, padded to rank 2 if required
     */
    public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
        val ret = new long[Math.max(2,tensorProto.getDimsCount())];
        int dimCount = tensorProto.getDimsCount();
        if(dimCount >= 2)
            for(int i = 0; i < ret.length; i++) {
                ret[i] = (int) tensorProto.getDims(i);
            }
        else {
            ret[0] = 1;
            for(int i = 1; i < ret.length; i++) {
                ret[i] = (int) tensorProto.getDims(i - 1);
            }
        }

        return ret;
    }

    @Override
    public Set<String> opsToIgnore() {
        return Collections.emptySet();
    }

    @Override
    public String getInputFromNode(Onnx.NodeProto node, int index) {
        return node.getInput(index);
    }

    @Override
    public int numInputsFor(Onnx.NodeProto nodeProto) {
        return nodeProto.getInputCount();
    }

    @Override
    public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
        return Longs.toArray(attr.getT().getDimsList());
    }

    @Override
    public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
        Map<String,Onnx.AttributeProto> proto = new HashMap<>();
        for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
            Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
            proto.put(attributeProto.getName(),attributeProto);
        }
        return proto;
    }

    @Override
    public String getName(Onnx.NodeProto nodeProto) {
        return nodeProto.getName();
    }

    @Override
    public boolean alreadySeen(Onnx.NodeProto nodeProto) {
        return false;
    }

    @Override
    public boolean isVariableNode(Onnx.NodeProto nodeProto) {
        return nodeProto.getOpType().contains("Var");
    }

    @Override
    public boolean shouldSkip(Onnx.NodeProto opType) {
        return false;
    }

    @Override
    public boolean hasShape(Onnx.NodeProto nodeProto) {
        return false;
    }

    @Override
    public long[] getShape(Onnx.NodeProto nodeProto) {
        return null;
    }

    @Override
    public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
        return null;
    }

    @Override
    public String getOpType(Onnx.NodeProto nodeProto) {
        return nodeProto.getOpType();
    }

    @Override
    public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
        return graphProto.getNodeList();
    }

}
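Both getNDArrayFromTensor and mapTensorProto above funnel the proto's raw bytes through a direct, native-ordered buffer before handing them to Nd4j. That copy step in isolation, with the dtype, shape, and byte contents hardcoded purely for illustration:

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import org.nd4j.linalg.api.buffer.DataBuffer;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class RawBytesToNDArray {
        public static void main(String[] args) {
            // Pretend these 16 bytes came from Onnx.TensorProto.getRawData()
            ByteBuffer raw = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder());
            for (float f : new float[]{1f, 2f, 3f, 4f}) raw.putFloat(f);
            raw.rewind();

            // Copy into a direct buffer so the native backend can wrap it without another copy
            ByteBuffer direct = ByteBuffer.allocateDirect(raw.capacity()).order(ByteOrder.nativeOrder());
            direct.put(raw);
            direct.rewind();

            DataBuffer buffer = Nd4j.createBuffer(direct, DataType.FLOAT, 4);
            INDArray arr = Nd4j.create(buffer).reshape(2, 2);
            System.out.println(arr);
        }
    }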
(File diff suppressed because it is too large)
@@ -226,22 +226,24 @@ public class TensorFlowImportValidator {
     }

     public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException {
-        TFGraphMapper m = TFGraphMapper.getInstance();

         try {
             int opCount = 0;
             Set<String> opNames = new HashSet<>();

             try(InputStream bis = new BufferedInputStream(is)) {
-                GraphDef graphDef = m.parseGraphFrom(bis);
-                List<NodeDef> nodes = m.getNodeList(graphDef);
+                GraphDef graphDef = GraphDef.parseFrom(bis);
+                List<NodeDef> nodes = new ArrayList<>(graphDef.getNodeCount());
+                for( int i=0; i<graphDef.getNodeCount(); i++ ){
+                    nodes.add(graphDef.getNode(i));
+                }

                 if(nodes.isEmpty()){
                     throw new IllegalStateException("Error loading model for import - loaded graph def has no nodes (empty/corrupt file?): " + path);
                 }

                 for (NodeDef nd : nodes) {
-                    if (m.isVariableNode(nd) || m.isPlaceHolderNode(nd))
+                    if (TFGraphMapper.isVariableNode(nd) || TFGraphMapper.isPlaceHolder(nd))
                         continue;

                     String op = nd.getOp();
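The validator hunk above swaps the mapper indirection for protobuf's generated parser. The direct-parse pattern the new code relies on, as a standalone sketch (the file name is assumed):

    // Parse a frozen TF graph and collect the distinct op types, as the new validator does
    try (InputStream is = new BufferedInputStream(new FileInputStream("frozen_model.pb"))) {
        GraphDef graphDef = GraphDef.parseFrom(is);
        Set<String> opTypes = new HashSet<>();
        for (int i = 0; i < graphDef.getNodeCount(); i++) {
            opTypes.add(graphDef.getNode(i).getOp());
        }
        System.out.println("Distinct op types: " + opTypes);
    }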
@@ -86,6 +86,7 @@ import java.io.*;
 import java.nio.IntBuffer;
 import java.nio.LongBuffer;
 import java.util.*;
+import java.util.concurrent.atomic.AtomicLong;

 import static org.nd4j.linalg.factory.Nd4j.*;

@@ -124,6 +125,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     protected transient JvmShapeInfo jvmShapeInfo;

+    private static final AtomicLong arrayCounter = new AtomicLong(0);
+    protected transient final long arrayId = arrayCounter.getAndIncrement();
+
     //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over
     private static final int[][] tadFinalPermuteDimensions;

@@ -139,7 +143,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     }

     public BaseNDArray() {
-
     }

     @Override

@@ -4916,6 +4919,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {

     @Override
     public String toString(@NonNull NDArrayStrings options){
+        if(wasClosed())
+            return "<Closed NDArray, id=" + getId() + ", dtype=" + dataType() + ", shape=" + Arrays.toString(shape()) + ">";
         if (!isCompressed() && !preventUnpack)
             return options.format(this);
         else if (isCompressed() && compressDebug)

@@ -5600,4 +5605,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {

         return false;
     }

+    @Override
+    public long getId(){
+        return arrayId;
+    }
 }
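The new arrayId/getId pair gives each INDArray instance a JVM-unique identifier from a shared static counter, which the closed-array toString above uses for diagnostics. The pattern in isolation:

    import java.util.concurrent.atomic.AtomicLong;

    class Tracked {
        private static final AtomicLong COUNTER = new AtomicLong(0);
        // Assigned once at construction; unique per instance within this JVM, thread-safe
        final long id = COUNTER.getAndIncrement();
    }

    // new Tracked().id yields 0, then 1, 2, ... even under concurrent construction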
@@ -2814,4 +2814,10 @@ public interface INDArray extends Serializable, AutoCloseable {
      * @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int)
      */
     String toStringFull();

+    /**
+     * A unique ID for the INDArray object instance. Does not account for views.
+     * @return INDArray unique ID
+     */
+    long getId();
 }
@@ -24,6 +24,7 @@ import onnx.Onnx;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -200,48 +201,17 @@ public abstract class BaseOp extends DifferentialFunction implements Op {

     @Override
     public void setX(INDArray x) {
-        if (x == null) {
-            if (args() != null && args().length >= 1) {
-                SDVariable firstArg = args()[0];
-                if (firstArg.getArr() != null)
-                    this.x = firstArg.getArr();
-            } else
-                throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments");
-        } else
-            this.x = x;
+        this.x = x;
     }

     @Override
     public void setZ(INDArray z) {
-        if (z == null) {
-            SDVariable getResult = sameDiff.getVariable(zVertexId);
-            if (getResult != null) {
-                if (getResult.getArr() != null)
-                    this.z = getResult.getArr();
-                else if(sameDiff.getShapeForVarName(getResult.getVarName()) != null) {
-                    val shape = sameDiff.getShapeForVarName(getResult.getVarName());
-                    sameDiff.setArrayForVariable(getResult.getVarName(),getResult.getWeightInitScheme().create(getResult.dataType(), shape));
-                }
-                else
-                    throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
-
-            } else
-                throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
-        } else
-            this.z = z;
+        this.z = z;
     }

     @Override
     public void setY(INDArray y) {
-        if (y == null) {
-            if (args() != null && args().length > 1) {
-                SDVariable firstArg = args()[1];
-                if (firstArg.getArr() != null)
-                    this.y = firstArg.getArr();
-            } else
-                throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
-        } else
-            this.y = y;
+        this.y = y;
     }

     @Override

@@ -265,6 +235,12 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
         return z;
     }

+    @Override
+    public INDArray getInputArgument(int index){
+        Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index);
+        return index == 0 ? x : y;
+    }

     @Override
     public SDVariable[] outputVariables(String baseName) {
         if(zVertexId == null) {

@@ -403,4 +379,11 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
         //Always 1 for legacy/base ops
         return 1;
     }

+    @Override
+    public void clearArrays(){
+        x = null;
+        y = null;
+        z = null;
+    }
 }
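Together with the new clearArrays() method, the simplified setters let an execution engine drop x/y/z references as soon as an op's outputs have been consumed, so the memory manager can close or reuse the underlying buffers. A sketch of the intended call pattern; the executor loop and bookkeeping are hypothetical, while the methods on "op" match the diff above:

    // Hypothetical engine loop
    for (Op op : executionOrder) {
        Nd4j.getExecutioner().exec(op);   // run the op
        consume(op.z());                  // hypothetical: hand off the output
        op.clearArrays();                 // x = y = z = null per the BaseOp change
    }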
@@ -16,7 +16,6 @@

 package org.nd4j.linalg.api.ops;

-import org.nd4j.shade.guava.primitives.Ints;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.slf4j.Slf4j;

@@ -24,21 +23,14 @@ import lombok.val;
 import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
-import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
-import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.api.shape.Shape;
-import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;

-import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;

@@ -71,10 +63,6 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
             this.keepDims = keepDims;
             this.xVertexId = i_v.getVarName();
             sameDiff.addArgsFor(new String[]{xVertexId},this);
-            if(Shape.isPlaceholderShape(i_v.getShape())) {
-                sameDiff.addPropertyToResolve(this,i_v.getVarName());
-            }
-
         } else {
             throw new IllegalArgumentException("Input not null variable.");
         }

@@ -219,14 +207,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {

     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        if (!attributesForNode.containsKey("axes")) {
-            this.dimensions = new int[] { Integer.MAX_VALUE };
-        }
-        else {
-            val map = OnnxGraphMapper.getInstance().getAttrMap(node);
-            val dims = Ints.toArray(map.get("axes").getIntsList());
-            this.dimensions = dims;
-        }
     }

     @Override
@@ -119,4 +119,9 @@ public interface CustomOp {
      * otherwise throws an {@link org.nd4j.linalg.exception.ND4JIllegalStateException}
      */
     void assertValidForExecution();

+    /**
+     * Clear the input and output INDArrays, if any are set
+     */
+    void clearArrays();
 }
@@ -263,7 +263,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     @Override
     public INDArray[] outputArguments() {
         if (!outputArguments.isEmpty()) {
-            return outputArguments.toArray(new INDArray[outputArguments.size()]);
+            return outputArguments.toArray(new INDArray[0]);
         }
         return new INDArray[0];
     }

@@ -271,7 +271,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     @Override
     public INDArray[] inputArguments() {
         if (!inputArguments.isEmpty())
-            return inputArguments.toArray(new INDArray[inputArguments.size()]);
+            return inputArguments.toArray(new INDArray[0]);
         return new INDArray[0];
     }

@@ -389,6 +389,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     }

     public void setInputArgument(int index, INDArray input) {
+        if(index >= inputArguments.size() ){
+            List<INDArray> oldArgs = inputArguments;
+            inputArguments = new ArrayList<>(index+1);
+            inputArguments.addAll(oldArgs);
+            while(inputArguments.size() <= index)
+                inputArguments.add(null);
+        }
         inputArguments.set(index, input);
     }

@@ -400,12 +407,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
     }

     public void setOutputArgument(int index, INDArray output) {
-        if(index == outputArguments.size()){
-            //For example, setOutputArgument(0,arr) on empty list
-            outputArguments.add(output);
-        } else {
-            outputArguments.set(index, output);
-        }
+        while(index >= outputArguments.size()){
+            //Resize list, in case we want to specify arrays not in order they are defined
+            //For example, index 1 on empty list, then index 0
+            outputArguments.add(null);
+        }
+        outputArguments.set(index, output);
     }

     @Override
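setInputArgument and setOutputArgument now pad the backing list with nulls up to the requested index, so arguments can be supplied out of order. The resize-then-set idiom on its own:

    // Grow a list with null padding so an element can be set at an arbitrary index
    static <T> void setAt(java.util.List<T> list, int index, T value) {
        while (list.size() <= index) {
            list.add(null);
        }
        list.set(index, value);
    }

    // e.g. setAt(outputs, 1, arrB); setAt(outputs, 0, arrA);  // out-of-order is now fine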
@@ -608,6 +615,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {

     }

+    @Override
+    public void clearArrays(){
+        inputArguments.clear();
+        outputArguments.clear();
+    }
+
     protected static INDArray[] wrapOrNull(INDArray in){
         return in == null ? null : new INDArray[]{in};
     }
@@ -167,4 +167,9 @@ public interface Op {
      * @return the equivalent {@link CustomOp}
      */
     CustomOp toCustomOp();

+    /**
+     * Clear the input and output INDArrays, if any are set
+     */
+    void clearArrays();
 }
@@ -25,6 +25,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast {

     @Override
     public String tensorflowName() {
-        return "AdjustContrast";
+        return "AdjustContrastV2";
     }
 }
@@ -245,4 +245,9 @@ public class ScatterUpdate implements CustomOp {
     public void assertValidForExecution() {

     }

+    @Override
+    public void clearArrays() {
+        op.clearArrays();
+    }
 }
@@ -39,13 +39,18 @@ import java.util.*;
 @NoArgsConstructor
 public class BiasAdd extends DynamicCustomOp {

+    protected boolean nchw = true;

-    public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias) {
+    public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) {
         super(null, sameDiff, new SDVariable[] {input, bias}, false);
+        bArguments.clear();
+        bArguments.add(nchw);
     }

-    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
+    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
         super(new INDArray[]{input, bias}, wrapOrNull(output));
+        bArguments.clear();
+        bArguments.add(nchw);
     }

     @Override

@@ -56,7 +61,11 @@ public class BiasAdd extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
+        if(attributesForNode.containsKey("data_format")){
+            nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8());
+        }
+        bArguments.clear();
+        bArguments.add(nchw);
     }

     @Override
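The layout flag is forwarded to the native op through bArguments and defaults to NCHW unless TF's data_format attribute says otherwise. A hedged usage sketch of the updated constructor; the array creation and execution calls are shown for illustration only:

    INDArray input = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4);   // NCHW activations
    INDArray bias  = Nd4j.rand(DataType.FLOAT, 3);            // one bias per channel
    INDArray out   = input.ulike();                           // same shape/type as input
    Nd4j.exec(new BiasAdd(input, bias, out, true));           // true -> NCHW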
@@ -1,402 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.util.*;

/**
 * Equivalent to tensorflow's conditional op.
 * Runs one of 2 {@link SameDiff.SameDiffFunctionDefinition}
 * depending on a predicate {@link org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional}
 *
 * @author Adam Gibson
 */
@NoArgsConstructor
@Slf4j
public class If extends DifferentialFunction implements CustomOp {

    @Getter
    protected SameDiff loopBodyExecution,predicateExecution,falseBodyExecution;

    @Getter
    protected SameDiffConditional predicate;
    @Getter
    protected SameDiffFunctionDefinition trueBody,falseBody;

    @Getter
    protected String blockName,trueBodyName,falseBodyName;

    @Getter
    protected SDVariable[] inputVars;

    @Getter
    protected Boolean trueBodyExecuted = null;

    @Getter
    protected SDVariable targetBoolean;

    protected SDVariable dummyResult;

    @Getter
    @Setter
    protected SDVariable[] outputVars;

    public If(If ifStatement) {
        this.sameDiff = ifStatement.sameDiff;
        this.outputVars = ifStatement.outputVars;
        this.falseBodyExecution = ifStatement.falseBodyExecution;
        this.trueBodyExecuted = ifStatement.trueBodyExecuted;
        this.falseBody = ifStatement.falseBody;
        this.trueBodyExecuted = ifStatement.trueBodyExecuted;
        this.dummyResult = ifStatement.dummyResult;
        this.inputVars = ifStatement.inputVars;
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme(), DataType.FLOAT, 1);
        if(sameDiff.getShapeForVarName(dummyResult.getVarName()) == null)
            sameDiff.putShapeForVarName(dummyResult.getVarName(),new long[]{1,1});
    }

    @Builder
    public If(String blockName,
              SameDiff parent,
              SDVariable[] inputVars,
              SameDiffFunctionDefinition conditionBody,
              SameDiffConditional predicate,
              SameDiffFunctionDefinition trueBody,
              SameDiffFunctionDefinition falseBody) {

        this.sameDiff = parent;
        parent.putOpForId(getOwnName(),this);
        this.inputVars = inputVars;
        this.predicate = predicate;

        parent.addArgsFor(inputVars,this);
        this.trueBody = trueBody;
        this.falseBody = falseBody;
        this.blockName = blockName;
        //need to add the op to the list of ops to be executed when running backwards
        this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
        parent.addOutgoingFor(new SDVariable[]{dummyResult},this);

        //create a samediff sub graph for running just the execution
        //return a reference to the loop for referencing during actual execution
        SameDiff sameDiff = SameDiff.create();
        //store the reference to the result array and the same diff execution instance
        this.targetBoolean = predicate.eval(sameDiff,conditionBody, inputVars);
        this.predicateExecution = sameDiff;
        //store references to the loop body
        String trueBodyName = "true-body-" + UUID.randomUUID().toString();
        this.trueBodyName = trueBodyName;

        String falseBodyName = "false-body-" + UUID.randomUUID().toString();
        this.falseBodyName = trueBodyName;

        //running define function will setup a proper same diff instance
        this.loopBodyExecution = parent.defineFunction(trueBodyName,trueBody,inputVars);
        this.falseBodyExecution = parent.defineFunction(falseBodyName,falseBody,inputVars);
        parent.defineFunction(blockName,conditionBody,inputVars);
        parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(),sameDiff);
        //get a reference to the actual loop body
        this.loopBodyExecution = parent.getFunction(trueBodyName);
    }

    /**
     * Toggle whether the true body was executed
     * or the false body
     * @param trueBodyExecuted
     */
    public void exectedTrueOrFalse(boolean trueBodyExecuted) {
        if(trueBodyExecuted)
            this.trueBodyExecuted = true;
        else
            this.trueBodyExecuted = false;
    }

    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[]{dummyResult};
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        List<SDVariable> ret = new ArrayList<>();
        ret.addAll(Arrays.asList(new IfDerivative(this).outputVariables()));
        return ret;
    }

    @Override
    public String toString() {
        return opName();
    }

    @Override
    public String opName() {
        return "if";
    }

    @Override
    public long opHash() {
        return HashUtil.getLongHash(opName());
    }

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public long[] iArgs() {
        return new long[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public boolean[] bArgs() {
        return new boolean[0];
    }

    @Override
    public void addIArgument(int... arg) {
    }

    @Override
    public void addIArgument(long... arg) {
    }

    @Override
    public void addBArgument(boolean... arg) {
    }

    @Override
    public void removeIArgument(Integer arg) {
    }

    @Override
    public Boolean getBArgument(int index) {
        return null;
    }

    @Override
    public Long getIArgument(int index) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double... arg) {
    }

    @Override
    public void removeTArgument(Double arg) {
    }

    @Override
    public Double getTArgument(int index) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public int numBArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray... arg) {
    }

    @Override
    public void removeInputArgument(INDArray arg) {
    }

    @Override
    public INDArray getInputArgument(int index) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray... arg) {
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CONDITIONAL;
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        //cond is only part of while loops
        if(nodeDef.getName().contains("/cond/"))
            return;
        //usually should be a merge node for a conditional
        val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph);

        val trueScopeGraphDefBuilder = GraphDef.newBuilder();
        for(val node : ifNodes.getTrueNodes()) {
            trueScopeGraphDefBuilder.addNode(node);
        }

        val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build());

        val falseScopeGraphDefBuilder = GraphDef.newBuilder();
        for(val node : ifNodes.getFalseNodes()) {
            falseScopeGraphDefBuilder.addNode(node);
        }

        val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build());

        val condScopeGraphDefBuilder = GraphDef.newBuilder();
        for(val node : ifNodes.getCondNodes()) {
            condScopeGraphDefBuilder.addNode(node);
        }

        val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build());

        initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope);
        initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope);
        initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope);

        this.loopBodyExecution = trueScope;
        this.falseBodyExecution = falseScope;
        this.predicateExecution = condScope;
    }

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL));
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return null;
    }

    @Override
    public void assertValidForExecution() {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }
}
@@ -1,93 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

import java.util.List;

@NoArgsConstructor
public class IfDerivative extends If {

    private If ifDelegate;

    public IfDerivative(If ifBlock) {
        super(ifBlock);
        this.ifDelegate = ifBlock;
    }

    @Override
    public Boolean getTrueBodyExecuted() {
        return ifDelegate.trueBodyExecuted;
    }

    @Override
    public SameDiffFunctionDefinition getFalseBody() {
        return ifDelegate.falseBody;
    }

    @Override
    public SameDiff getFalseBodyExecution() {
        return ifDelegate.falseBodyExecution;
    }

    @Override
    public String getBlockName() {
        return ifDelegate.blockName;
    }

    @Override
    public String getFalseBodyName() {
        return ifDelegate.falseBodyName;
    }

    @Override
    public SameDiff getLoopBodyExecution() {
        return ifDelegate.loopBodyExecution;
    }

    @Override
    public SameDiffConditional getPredicate() {
        return ifDelegate.getPredicate();
    }

    @Override
    public SameDiff getPredicateExecution() {
        return ifDelegate.predicateExecution;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return super.calculateOutputShape();
    }

    @Override
    public String opName() {
        return "if_bp";
    }

    @Override
    public List<SDVariable> diff(List<SDVariable> i_v1) {
        throw new UnsupportedOperationException("Unable to take the derivative of the derivative for if");
    }
}
@@ -1,32 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.Builder;
import lombok.Data;
import org.tensorflow.framework.NodeDef;

import java.util.List;

@Builder
@Data
public class IfImportState {
    private List<NodeDef> condNodes;
    private List<NodeDef> trueNodes;
    private List<NodeDef> falseNodes;
    private String falseBodyScopeName,trueBodyScopeName,conditionBodyScopeName;
}
@@ -55,7 +55,7 @@ public class Select extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
     }
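
[Illustrative sketch, not part of the patch] The Select change above is one instance of a mechanical migration applied throughout this diff: TFGraphMapper's import helpers become static, so the getInstance() singleton hop disappears while the argument list stays the same. A hedged, compilable sketch of the new-style call with the signature copied from the hunks in this diff (the wrapper class and method name are hypothetical):

import java.util.Map;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class MapperMigrationSketch {
    // The op instance stands in for `this` inside an initFromTensorFlow override.
    static void initViaStaticMapper(DynamicCustomOp op, NodeDef nodeDef,
                                    Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // previously: TFGraphMapper.getInstance().initFunctionFromProperties(...)
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), op, attributesForNode, nodeDef, graph);
    }
}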
@@ -1,660 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Equivalent to tensorflow's while loop
 * Takes in:
 * loopVars
 * loop body
 * condition
 *
 * runs loop till condition is false.
 * @author Adam Gibson
 */
@NoArgsConstructor
@Slf4j
public class While extends DifferentialFunction implements CustomOp {
    private AtomicInteger startPosition;

    @Getter
    protected SameDiff loopBodyExecution,predicateExecution;

    @Getter
    protected SameDiffConditional predicate;
    @Getter
    protected SameDiffFunctionDefinition trueBody;

    @Getter
    protected String blockName,trueBodyName;

    @Getter
    protected SDVariable[] inputVars;

    @Getter
    protected SDVariable targetBoolean;

    protected SDVariable dummyResult;

    @Getter
    @Setter
    protected SDVariable[] outputVars;

    @Getter
    protected int numLooped = 0;

    /**
     * Mainly meant for tensorflow import.
     * This allows {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}
     * to continue from a parent while loop
     * using the same graph
     * @param startPosition the start position for the import scan
     */
    public While(AtomicInteger startPosition) {
        this.startPosition = startPosition;
    }

    public While(While whileStatement) {
        this.sameDiff = whileStatement.sameDiff;
        this.outputVars = whileStatement.outputVars;
        this.loopBodyExecution = whileStatement.loopBodyExecution;
        this.numLooped = whileStatement.numLooped;
        this.dummyResult = whileStatement.dummyResult;
        this.predicate = whileStatement.predicate;
        this.predicateExecution = whileStatement.predicateExecution;
        this.inputVars = whileStatement.inputVars;
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
    }

    @Builder
    public While(String blockName,
                 SameDiff parent,
                 SDVariable[] inputVars,
                 SameDiffConditional predicate,
                 SameDiffFunctionDefinition condition,
                 SameDiffFunctionDefinition trueBody) {
        init(blockName,parent,inputVars,predicate,condition,trueBody);
    }

    private void init(String blockName,
                      SameDiff parent,
                      SDVariable[] inputVars,
                      SameDiffConditional predicate,
                      SameDiffFunctionDefinition condition,
                      SameDiffFunctionDefinition trueBody) {
        this.sameDiff = parent;
        this.inputVars = inputVars;
        this.predicate = predicate;
        this.trueBody = trueBody;
        this.blockName = blockName;
        this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
        parent.putOpForId(getOwnName(),this);

        parent.addArgsFor(inputVars,this);
        parent.addOutgoingFor(new SDVariable[]{dummyResult},this);

        //create a samediff sub graph for running just the execution
        //return a reference to the loop for referencing during actual execution
        SameDiff sameDiff = SameDiff.create();
        //store the reference to the result array and the same diff execution instance
        this.targetBoolean = predicate.eval(sameDiff,condition, inputVars);
        this.predicateExecution = sameDiff;
        //store references to the loop body
        String trueBodyName = "true-body-" + UUID.randomUUID().toString();
        this.trueBodyName = trueBodyName;
        //running define function will setup a proper same diff instance
        parent.defineFunction(trueBodyName,trueBody,inputVars);
        parent.defineFunction(blockName,condition,inputVars);
        parent.putSubFunction("predicate-eval-body",sameDiff);
        //get a reference to the actual loop body
        this.loopBodyExecution = parent.getFunction(trueBodyName);
    }

    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[]{dummyResult};
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        List<SDVariable> ret = new ArrayList<>();
        ret.addAll(Arrays.asList(new WhileDerivative(this).outputVariables()));
        return ret;
    }

    /**
     * Increments the loop counter.
     * This should be called when the loop
     * actually executes.
     */
    public void incrementLoopCounter() {
        numLooped++;
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        doImport(nodeDef,initWith,attributesForNode,graph,new LinkedHashSet<String>(),new AtomicInteger(0));
    }

    private void doImport(NodeDef nodeDef,SameDiff initWith,Map<String,AttrValue> attributesForNode,GraphDef graph,Set<String> skipSet,AtomicInteger currIndex) {
        val uniqueId = java.util.UUID.randomUUID().toString();
        skipSet.add(nodeDef.getName());
        val scopeCondition = SameDiff.create();
        val scopeLoop = SameDiff.create();
        initWith.putSubFunction("condition-" + uniqueId,scopeCondition);
        initWith.putSubFunction("loopbody-" + uniqueId,scopeLoop);
        this.loopBodyExecution = scopeLoop;
        this.predicateExecution = scopeCondition;
        this.startPosition = currIndex;

        log.info("Adding 2 new scopes for WHILE {}");

        val nodes = graph.getNodeList();

        /**
         * Plan is simple:
         * 1) we read all declarations of variables used within loop
         * 2) we set up conditional scope
         * 3) we set up body scope
         * 4) ???
         * 5) PROFIT!
         */

        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("enter")) {
                //skipSet.add(tfNode.getName());
                break;
            }

            // if (skipSet.contains(tfNode.getName()))
            //     continue;

            skipSet.add(tfNode.getName());

            val vars = new SDVariable[tfNode.getInputCount()];
            for (int e = 0; e < tfNode.getInputCount(); e++) {
                val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
                vars[e] = initWith.getVariable(input) == null ? initWith.var(input, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(input);
                scopeCondition.var(vars[e]);
                scopeLoop.var(vars[e]);
            }

            this.inputVars = vars;
        }

        // now we're skipping Merge step, since we've already captured variables at Enter step
        int mergedCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("merge")) {
                scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor) null,new ZeroInitScheme());
                break;
            }

            skipSet.add(tfNode.getName());
            val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor)null,new ZeroInitScheme());
            scopeCondition.var(var);
            initWith.var(var);
            mergedCnt++;
        }

        // now, we're adding conditional scope
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            // we're parsing up to condition
            if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
                skipSet.add(tfNode.getName());
                currIndex.incrementAndGet();
                break;
            }

            boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
            boolean isVar = tfNode.getOp().startsWith("VariableV");
            boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");

            if (isConst || isVar || isPlaceholder) {
                val var = scopeCondition.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
                scopeLoop.var(var);
                initWith.var(var);
                log.info("Adding condition var [{}]", var.getVarName());
            }
            else if(!skipSet.contains(tfNode.getName())) {
                val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
                func.initFromTensorFlow(tfNode,scopeCondition,nodeDef.getAttrMap(),graph);
                func.setSameDiff(scopeLoop);
            }

            skipSet.add(tfNode.getName());
        }

        // time to skip some Switch calls
        int switchCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            // we're parsing up to condition
            if (!tfNode.getOp().equalsIgnoreCase("Switch"))
                break;

            switchCnt++;
            skipSet.add(tfNode.getName());
        }

        // now we're parsing Identity step
        int identityCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
                break;
            }

            val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
            func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);
            func.setSameDiff(scopeLoop);

            val variables = new SDVariable[tfNode.getInputCount()];
            for(int i = 0; i < tfNode.getInputCount(); i++) {
                val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
                if(testVar == null) {
                    variables[i] = initWith.var(tfNode.getInput(i), (LongShapeDescriptor) null,new ZeroInitScheme());
                    scopeCondition.var(variables[i]);
                    scopeLoop.var(variables[i]);
                    continue;
                }
                else {
                    variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
                    scopeCondition.var(variables[i]);
                    scopeLoop.var(variables[i]);
                }
            }

            scopeLoop.addArgsFor(variables,func);
            skipSet.add(tfNode.getName());
        }

        // parsing body scope
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (skipSet.contains(tfNode.getName())) {
                log.info("Skipping: {}", tfNode.getName());
                continue;
            }

            if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
                // skipSet.add(tfNode.getName());
                break;
            }

            if (skipSet.contains(tfNode.getName())) {
                log.info("Skipping: {}", tfNode.getName());
                continue;
            }

            boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
            boolean isVar = tfNode.getOp().startsWith("VariableV");
            boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");

            if (isConst || isVar || isPlaceholder) {
                val var = scopeLoop.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
                log.info("Adding body var [{}]",var.getVarName());
            } else {
                log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());

                if (tfNode.getOp().equalsIgnoreCase("enter")) {
                    log.info("NEW LOOP ----------------------------------------");
                    val func = new While(currIndex);
                    func.doImport(nodeDef,initWith,attributesForNode,graph,skipSet,currIndex);
                    func.setSameDiff(initWith);
                    log.info("END LOOP ----------------------------------------");
                } else {
                    val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());

                    func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);

                    func.setSameDiff(scopeCondition);

                    val variables = new SDVariable[tfNode.getInputCount()];
                    for(int i = 0; i < tfNode.getInputCount(); i++) {
                        val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
                        variables[i] = scopeCondition.getVariable(name);
                        if(variables[i] == null) {
                            if(scopeLoop.getVariable(name) == null)
                                variables[i] = scopeCondition.var(initWith.getVariable(name));
                            else if(scopeLoop.getVariable(name) != null)
                                variables[i] = scopeLoop.getVariable(name);
                            else
                                variables[i] = scopeLoop.var(name, Nd4j.scalar(1.0));
                        }
                    }

                    scopeLoop.addArgsFor(variables,func);
                }
            }

            skipSet.add(tfNode.getName());
        }

        val returnInputs = new ArrayList<SDVariable>();
        val returnOutputs = new ArrayList<SDVariable>();

        // mapping NextIterations, to Return op
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
                break;

            skipSet.add(tfNode.getName());

            val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
            val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
            returnInputs.add(input);
        }

        this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
        this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
        initWith.addArgsFor(inputVars,this);
        initWith.addOutgoingFor(outputVars,this);

        // we should also map While/Exit to libnd4j while
        int exitCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
                //skipSet.add(tfNode.getName());
                break;
            }

            skipSet.add(tfNode.getName());
            val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
            val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
        }

        //the output of the condition should always be a singular scalar
        //this is a safe assumption
        val conditionVars = scopeCondition.ops();
        if(conditionVars.length < 1) {
            throw new ND4JIllegalArgumentException("No functions found!");
        }
        this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];

        log.info("-------------------------------------------");
    }

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    @Override
    public String toString() {
        return opName();
    }

    @Override
    public String opName() {
        return "while";
    }

    @Override
    public long opHash() {
        return opName().hashCode();
    }

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public long[] iArgs() {
        return new long[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public void addIArgument(int... arg) {
    }

    @Override
    public void addIArgument(long... arg) {
    }

    @Override
    public void removeIArgument(Integer arg) {
    }

    @Override
    public Long getIArgument(int index) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double... arg) {
    }

    @Override
    public void removeTArgument(Double arg) {
    }

    @Override
    public Double getTArgument(int index) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public int numBArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray... arg) {
    }

    @Override
    public void removeInputArgument(INDArray arg) {
    }

    @Override
    public boolean[] bArgs() {
        return new boolean[0];
    }

    @Override
    public void addBArgument(boolean... arg) {
    }

    @Override
    public Boolean getBArgument(int index) {
        return null;
    }

    @Override
    public INDArray getInputArgument(int index) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray... arg) {
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        List<LongShapeDescriptor> ret = new ArrayList<>();
        for(SDVariable var : args()) {
            ret.add(sameDiff.getShapeDescriptorForVarName(var.getVarName()));
        }
        return ret;
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return CustomOpDescriptor.builder().build();
    }

    @Override
    public void assertValidForExecution() {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + opName());
    }

    @Override
    public String[] tensorflowNames() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }

    @Override
    public Op.Type opType() {
        return Op.Type.LOOP;
    }
}
@@ -1,96 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.Op;

/**
 * While loop derivative
 * @author Adam Gibson
 */
@NoArgsConstructor
public class WhileDerivative extends While {
    private While delegate;

    public WhileDerivative(While delegate) {
        super(delegate);
        this.delegate = delegate;
    }

    @Override
    public SameDiffFunctionDefinition getTrueBody() {
        return delegate.trueBody;
    }

    @Override
    public String getTrueBodyName() {
        return delegate.getTrueBodyName();
    }

    @Override
    public SameDiffConditional getPredicate() {
        return delegate.getPredicate();
    }

    @Override
    public SameDiff getPredicateExecution() {
        return delegate.getPredicateExecution();
    }

    @Override
    public SDVariable[] getInputVars() {
        return delegate.getInputVars();
    }

    @Override
    public String getBlockName() {
        return delegate.getBlockName();
    }

    @Override
    public SameDiff getLoopBodyExecution() {
        return delegate.getLoopBodyExecution();
    }

    @Override
    public int getNumLooped() {
        return delegate.getNumLooped();
    }

    @Override
    public String opName() {
        return "while_bp";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CONDITIONAL;
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow name for while backprop");
    }
}
@@ -55,7 +55,7 @@ public abstract class BaseCompatOp extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
     }

     @Override
@@ -32,9 +32,11 @@ import java.util.List;
 import java.util.Map;

 public class LoopCond extends BaseCompatOp {
+
+    public static final String OP_NAME = "loop_cond";

     @Override
     public String opName() {
-        return "loop_cond";
+        return OP_NAME;
     }

     @Override
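
[Hypothetical usage sketch] With the constant introduced above, callers can reference the op name symbolically instead of repeating the "loop_cond" literal. The package path below is assumed from the surrounding compat ops and may differ:

public class LoopCondNameSketch {
    public static void main(String[] args) {
        // Prints "loop_cond"; LoopCond.OP_NAME replaces scattered string literals.
        System.out.println(org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.OP_NAME);
    }
}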
@@ -74,8 +74,6 @@ public class CropAndResize extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-
         String method = attributesForNode.get("method").getS().toStringUtf8();
         if(method.equalsIgnoreCase("nearest")){
             this.method = Method.NEAREST;
@@ -120,4 +120,10 @@ public class ExtractImagePatches extends DynamicCustomOp {
         //TF includes redundant leading and trailing 1s for kSizes, strides, rates (positions 0/3)
         return new int[]{(int)ilist.getI(1), (int)ilist.getI(2)};
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
 }
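
[Illustrative sketch, not part of the patch] The calculateOutputDataTypes addition above follows a simple pass-through pattern for single-input ops: validate that exactly one input datatype arrived, then echo it as the output datatype. A minimal standalone version with a hypothetical class name; Preconditions and DataType are the nd4j types used in the hunk:

import java.util.Collections;
import java.util.List;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;

public class DtypePassThroughSketch {
    public static List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        // Exactly one input expected; the output inherits its datatype unchanged
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1,
                "Expected exactly 1 input datatype, got %s", inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}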
@@ -74,7 +74,7 @@ public class ResizeBilinear extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         this.alignCorners = attributesForNode.get("align_corners").getB();
         addArgs();
@@ -50,7 +50,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
     }

     @Override
@@ -26,8 +26,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.base.Preconditions;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -41,7 +39,6 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;

-import java.lang.reflect.Field;
 import java.util.*;

@@ -106,7 +103,7 @@ public class BatchNorm extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta]
         SameDiffOp op = initWith.getOps().get(this.getOwnName());
         List<String> list = op.getInputsToOp();

@@ -140,8 +137,7 @@ public class BatchNorm extends DynamicCustomOp {

     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
+
     }

     @Override
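
[Illustrative sketch, not part of the patch] The comment in the BatchNorm hunk above documents a pure input re-ordering: TF supplies [input, gamma, beta, mean, variance] while libnd4j expects [input, mean, variance, gamma, beta]. A self-contained sketch of that permutation, using plain list shuffling and illustrative class/method names, with no nd4j types:

import java.util.Arrays;
import java.util.List;

public class BatchNormInputOrderSketch {
    // in: [input, gamma, beta, mean, variance] -> out: [input, mean, variance, gamma, beta]
    public static List<String> tfToLibnd4jOrder(List<String> in) {
        return Arrays.asList(in.get(0), in.get(3), in.get(4), in.get(1), in.get(2));
    }

    public static void main(String[] args) {
        System.out.println(tfToLibnd4jOrder(
                Arrays.asList("input", "gamma", "beta", "mean", "variance")));
        // -> [input, mean, variance, gamma, beta]
    }
}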
@@ -21,33 +21,20 @@ import lombok.Getter;
 import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
-import org.nd4j.imports.descriptors.properties.AttributeAdapter;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
-import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter;
-import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
-import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
-import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.util.ArrayUtil;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;

 import java.lang.reflect.Field;
-import java.util.*;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;

 /**
@@ -31,7 +31,6 @@ import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.descriptors.properties.adapters.*;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -122,7 +121,7 @@ public class Conv2D extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }

@@ -138,8 +137,7 @@ public class Conv2D extends DynamicCustomOp {

     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
+
     }
@@ -251,7 +251,7 @@ public class Conv3D extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }
@@ -198,7 +198,7 @@ public class DeConv2D extends DynamicCustomOp {
         val args = args();
         INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
         if (arr == null) {
-            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
+            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
             // TODO: arguable. it might be easier to permute weights once
             //arr = (arr.permute(3, 2, 0, 1).dup('c'));
             val varForOp = initWith.getVariable(args[1].getVarName());
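Besides becoming static, getNDArrayFromTensor is narrowed from three parameters (input name, NodeDef, GraphDef) to just the NodeDef, presumably because the constant tensor can be read off the node itself rather than resolved through the whole graph. A hedged sketch of the resulting call pattern, reusing the variables from the hunk above (the same shape recurs in DeConv3D below):

// Resolve the weights variable; fall back to reading the constant tensor
// straight off the TF node when no array has been associated yet.
INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
if (arr == null) {
    arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);  // new one-arg static form
    if (arr != null)
        initWith.associateArrayWithVariable(arr, initWith.getVariable(args[1].getVarName()));
}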
@@ -214,7 +214,7 @@ public class DeConv2DTF extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }

@@ -240,9 +240,9 @@ public class DeConv2DTF extends DynamicCustomOp {
     }

     @Override
-    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //inShape, weights, input
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Collections.singletonList(inputDataTypes.get(0));
+        return Collections.singletonList(inputDataTypes.get(2));
     }
 }
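This hunk is a genuine bug fix rather than a refactor. Per the comment added to the signature, DeConv2DTF's inputs are ordered [inShape, weights, input]; index 0 is the shape tensor (an integer array), so the old code typed the op's output like the shape tensor. The output should instead inherit the datatype of the actual input at index 2. A small hypothetical illustration (class name and values invented for the example):

import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataType;

public class OutputDtypeExample {
    public static void main(String[] args) {
        // inShape is an integer shape tensor; weights and input are FLOAT
        List<DataType> in = Arrays.asList(DataType.INT, DataType.FLOAT, DataType.FLOAT);
        DataType before = in.get(0); // INT   -> wrong: typed like the shape tensor
        DataType after  = in.get(2); // FLOAT -> correct: typed like the actual input
        System.out.println(before + " -> " + after);
    }
}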
@@ -160,7 +160,7 @@ public class DeConv3D extends DynamicCustomOp {
         val args = args();
         INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
         if (arr == null) {
-            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
+            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
             val varForOp = initWith.getVariable(args[1].getVarName());
             if (arr != null)
                 initWith.associateArrayWithVariable(arr, varForOp);
@@ -77,7 +77,7 @@ public class DepthToSpace extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         boolean isNHWC = dataFormat.equals("NHWC");
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
@@ -29,14 +29,15 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.descriptors.properties.adapters.*;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
+import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
+import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
+import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
+import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -136,7 +137,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();

         /*
@@ -162,8 +163,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {

     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
     }


@@ -75,7 +75,7 @@ public class SpaceToDepth extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
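DepthToSpace and SpaceToDepth both encode the layout for libnd4j as an integer argument appended after blockSize (NHWC as 1, NCHW as 0); SpaceToDepth additionally treats a missing dataFormat as NHWC. The context line's ternary collapses to a plain boolean-or; a sketch of the equivalent form (helper name hypothetical, addIArgument as used in the hunk above):

// dataFormat == null ? true : dataFormat.equals("NHWC")  is equivalent to:
static void addDataFormatArgs(String dataFormat, int blockSize, DynamicCustomOp op) {
    boolean isNHWC = dataFormat == null || dataFormat.equals("NHWC");
    op.addIArgument(blockSize, isNHWC ? 1 : 0); // iArgs: [blockSize, layoutFlag]
}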
@@ -64,7 +64,7 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }

@@ -55,7 +55,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         //Switch order: TF uses [logits, labels]; libnd4j expects [labels, logits]
         SameDiffOp op = initWith.getOps().get(this.getOwnName());
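The context lines show the interesting part of this op's import: TF feeds the op [logits, labels] while the libnd4j kernel wants [labels, logits], so the imported op's input list must be swapped. The hunk cuts off before the swap itself; a hedged sketch of what such a reorder could look like on the SameDiffOp retrieved above (accessor names assumed from SameDiffOp's generated getters/setters, not shown in this diff):

// Swap the two recorded graph inputs: [logits, labels] -> [labels, logits]
java.util.List<String> in = op.getInputsToOp();
op.setInputsToOp(java.util.Arrays.asList(in.get(1), in.get(0)));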
@@ -64,7 +64,7 @@ public class Moments extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }

@@ -60,7 +60,7 @@ public class NormalizeMoments extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }

@@ -63,7 +63,7 @@ public class ScatterAdd extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
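The same one-line mapper change repeats verbatim in ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdSub, ScatterNdUpdate, and ScatterSub below. The use_locking context lines also show the standard idiom for reading an optional boolean attribute from a TF NodeDef; the nested ifs (and the redundant == true) collapse to the following equivalent sketch (what the op body does with the flag is outside this diff):

if (nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB()) {
    // the TF graph requested locking for this scatter update
}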
@@ -86,7 +86,7 @@ public class ScatterDiv extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -60,7 +60,7 @@ public class ScatterMax extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -60,7 +60,7 @@ public class ScatterMin extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -62,7 +62,7 @@ public class ScatterMul extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -67,7 +67,7 @@ public class ScatterNd extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -80,8 +80,8 @@ public class ScatterNd extends DynamicCustomOp {
     }

     @Override
-    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
-        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //Indices, updates, shape
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
         return Collections.singletonList(inputDataTypes.get(1));
     }

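Another arity fix in the same spirit as the DeConv2DTF one: per the added comment, ScatterNd takes three inputs, [indices, updates, shape], so the precondition now checks for 3; the returned datatype (index 1, the updates tensor) was already correct. A hypothetical illustration (class name and values invented for the example):

import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.buffer.DataType;

public class ScatterNdDtypeExample {
    public static void main(String[] args) {
        // ScatterNd inputs: indices (INT), updates (FLOAT), shape (INT)
        List<DataType> in = Arrays.asList(DataType.INT, DataType.FLOAT, DataType.INT);
        // the old size() == 2 check would have rejected every valid 3-input call
        assert in.size() == 3;
        System.out.println("output dtype: " + in.get(1)); // FLOAT, from updates
    }
}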
@@ -66,7 +66,7 @@ public class ScatterNdAdd extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -66,7 +66,7 @@ public class ScatterNdSub extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -66,7 +66,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -79,7 +79,7 @@ public class ScatterSub extends DynamicCustomOp {

     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
Some files were not shown because too many files have changed in this diff.