SameDiff execution, TF and memory management overhaul (#10)

* SameDiff execution memory management improvements, round 1
* Round 2
* Round 3
* Clear node outputs' closed array references; slight change to OpValidation internals to not rely on cached op outputs
* Next steps
* Next step
* Cleanup
* More polish
* Fix
* Add WeakIdentityHashmap
* Session fixes for control ops and next steps
* More fixes
* Fixes
* First steps for training session + in-line updating
* Next steps
* Fix losses and history during training
* More fixes
* BiasAdd and other fixes
* Don't use SDVariable.getArr() in TFGraphTestAllHelper (import tests)
* Small fix
* Cleanup
* First steps for new dependency tracking approach
* Start integrating dependency tracking for memory management
* Non-control op dependency tracking works/passes
* Switch/merge
* Next steps
* Cleanup and next steps
* Small fixes
* Fix dependency tracking issue for initial variables/constants
* Add check for aliases when determining if safe to close array
* First pass on new TF graph import class
* Import fixes, op fixes
* Cleanup and fixes for new TF import mapper
* More cleanup
* Partial implementation of new dependency tracker
* Next steps
* AbstractDependencyTracker for shared code
* Overhaul SameDiff graph execution (dependency tracking)
* More fixes, cleanup, next steps
* Add no-op memory manager, cleanup, fixes
* Fix switch dependency tracking
* INDArray.toString: no exception on closed arrays, just note they are closed
* Fix enter and exit dependency tracking
* TensorArray memory management fixes
* Add unique ID for INDArray instances
* Fix memory management for NextIteration outputs in multi-iteration loops
* Remove (now unnecessary) special case handling for nested enters
* Cleanup
* Handle control dependencies during execution; javadoc for memory managers
* Cleanup, polish, code comments, javadoc
* Cleanup and more javadoc
* Add memory validation for all TF import tests - ensure all arrays (except outputs) are released
* Clean up arrays waiting on unexecuted ops at the end of execution
* Fixes for enter op memory management in the context of multiple non-nested loops/frames
* Fix order-of-operation issues for dependency tracker
* Always clear op fields after execution to avoid leaks or unintended array reuse
* Small fixes
* Re-implement dtype conversion
* Fix for control dependencies execution (dependency tracking)
* Fix TF import overrides and filtering
* Fix for constant enter array dependency tracking
* DL4J fixes
* More DL4J fixes
* Cleanup and polish
* More polish and javadoc
* More logging level tweaks, small DL4J fix
* Small fix to DL4J SameDiffLayer
* Fix empty array deserialization, add extra deserialization checks
* FlatBuffers control dep serialization fixes; test serialization as part of all TF import tests
* Variable control dependencies serialization fix
* Fix issue with removing inputs for ops
* FlatBuffers NDArray deserialization fix
* FlatBuffers NDArray deserialization fix
* Small fix
* Small fix
* Final cleanup/polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
parent f31661e13b
commit 3f0b4a2d4c
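The DL4J-side changes in the diff below move from the old `sd.exec(...)` / `sd.execBackwards(...)` plus `SDVariable.getArr()` pattern to the map-returning `output(...)` and `calculateGradients(...)` methods. The following is a minimal standalone sketch of the new calling pattern; the graph, variable names and shapes are illustrative only and are not part of this PR.

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class SameDiffExecSketch {
    public static void main(String[] args) {
        // Hypothetical toy graph: input placeholder, one weight matrix, scalar loss
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
        SDVariable z = in.mmul("z", w);
        SDVariable loss = z.sum("loss");
        sd.setLossVariables("loss");

        Map<String, INDArray> placeholders = new HashMap<>();
        placeholders.put("input", Nd4j.rand(DataType.FLOAT, 2, 4));

        // Forward pass: requested outputs come back in a map keyed by variable name,
        // instead of being fetched afterwards via SDVariable.getArr()
        Map<String, INDArray> outputs = sd.output(placeholders, "z", "loss");
        INDArray zArr = outputs.get("z");

        // Backward pass: gradients are requested explicitly and returned in a map,
        // replacing execBackwards(...) followed by sd.grad(name).getArr()
        Map<String, INDArray> grads = sd.calculateGradients(placeholders, Collections.singletonList("w"));
        INDArray wGrad = grads.get("w");

        System.out.println(zArr);
        System.out.println(wGrad);
    }
}
```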
@@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test
public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {

- DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+ Nd4j.getRandom().setSeed(12345);
+ DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
@@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test
public void testComputationGraphFrozenLayerParamsAfterBackprop() {
+ Nd4j.getRandom().setSeed(12345);

- DataSet randomData = new DataSet(Nd4j.rand(100, 4,12345), Nd4j.rand(100, 1, 12345));
+ DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-";
@@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
*/
@Test
public void testFrozenLayerVsSgd() {
- DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+ Nd4j.getRandom().setSeed(12345);
+ DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
- .layer(0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(3)
-         .build()
- )
- .layer(1,
-     new DenseLayer.Builder()
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .nIn(3)
-         .nOut(4)
-         .build()
- ).layer(2,
-     new DenseLayer.Builder()
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .nIn(4)
-         .nOut(2)
-         .build()
-
- ).layer(3,
-     new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .activation(Activation.TANH)
-         .nIn(2)
-         .nOut(1)
-         .build()
- )
+ .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
+ .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
+ .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
+ .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
.build();

MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
@@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
- .layer(0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(3)
-         .build()
- )
- .layer(1,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new DenseLayer.Builder()
-             .nIn(3)
-             .nOut(4)
-             .build()
-     )
- )
- .layer(2,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new DenseLayer.Builder()
-             .nIn(4)
-             .nOut(2)
-             .build()
-     )
- ).layer(3,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-             .activation(Activation.TANH)
-             .nIn(2)
-             .nOut(1)
-             .build()
-     )
- )
+ .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
+ .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
+ .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
+ .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init();
@@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test
public void testComputationGraphVsSgd() {

- DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+ Nd4j.getRandom().setSeed(12345);
+ DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-";
@@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345)
.graphBuilder()
.addInputs("input")
- .addLayer(initialLayer,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(4)
-         .build(),
-     "input"
- )
- .addLayer(frozenBranchUnfrozenLayer0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(3)
-         .build(),
-     initialLayer
- )
- .addLayer(frozenBranchFrozenLayer1,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new DenseLayer.Builder()
-             .nIn(3)
-             .nOut(4)
-             .build()
-     ),
-     frozenBranchUnfrozenLayer0
- )
+ .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
+ .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
+ .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+         new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
- .addLayer(frozenBranchFrozenLayer2,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new DenseLayer.Builder()
-             .nIn(4)
-             .nOut(2)
-             .build()
-     ),
-     frozenBranchFrozenLayer1
- )
- .addLayer(unfrozenLayer0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(4)
-         .build(),
-     initialLayer
- )
- .addLayer(unfrozenLayer1,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(2)
-         .build(),
-     unfrozenLayer0
- )
- .addLayer(unfrozenBranch2,
-     new DenseLayer.Builder()
-         .nIn(2)
-         .nOut(1)
-         .build(),
-     unfrozenLayer1
- )
- .addVertex("merge",
-     new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
- .addLayer(frozenBranchOutput,
-     new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-         new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-             .activation(Activation.TANH)
-             .nIn(3)
-             .nOut(1)
-             .build()
-     ),
-     "merge"
- )
+         new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
+ .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
+ .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
+ .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
+ .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+ .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+         new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();
@@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345)
.graphBuilder()
.addInputs("input")
- .addLayer(initialLayer,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(4)
-         .build(),
-     "input"
- )
- .addLayer(frozenBranchUnfrozenLayer0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(3)
-         .build(),
-     initialLayer
- )
- .addLayer(frozenBranchFrozenLayer1,
-     new DenseLayer.Builder()
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .nIn(3)
-         .nOut(4)
-         .build(),
-     frozenBranchUnfrozenLayer0
- )
- .addLayer(frozenBranchFrozenLayer2,
-     new DenseLayer.Builder()
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .nIn(4)
-         .nOut(2)
-         .build()
-     ,
-     frozenBranchFrozenLayer1
- )
- .addLayer(unfrozenLayer0,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(4)
-         .build(),
-     initialLayer
- )
- .addLayer(unfrozenLayer1,
-     new DenseLayer.Builder()
-         .nIn(4)
-         .nOut(2)
-         .build(),
-     unfrozenLayer0
- )
- .addLayer(unfrozenBranch2,
-     new DenseLayer.Builder()
-         .nIn(2)
-         .nOut(1)
-         .build(),
-     unfrozenLayer1
- )
- .addVertex("merge",
-     new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
- .addLayer(frozenBranchOutput,
-     new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-         .updater(new Sgd(0.0))
-         .biasUpdater(new Sgd(0.0))
-         .activation(Activation.TANH)
-         .nIn(3)
-         .nOut(1)
-         .build()
-     ,
-     "merge"
- )
+ .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
+ .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
+ .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
+ .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
+ .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
+ .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
+ .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
+ .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+ .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
.setOutputs(frozenBranchOutput)
.build();
@@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
Map<String,INDArray> placeholders = new HashMap<>();
placeholders.put("input", f);
placeholders.put("label", l);
- sd.exec(placeholders, lossMse.getVarName());
- INDArray outSd = a1.getArr();
+ Map<String,INDArray> map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName());
+ INDArray outSd = map.get(a1.getVarName());
INDArray outDl4j = net.output(f);

assertEquals(testName, outDl4j, outSd);

@@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
//Check score
double scoreDl4j = net.score();
- double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
+ double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore();
assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

double lossRegScoreSD = sd.calcRegularizationScore();
@@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer {
val weightsShape = new long[] {outputSize, featureDim, nOut};
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
if (hasBias) {
- val biasShape = new long[] {1, nOut};
+ val biasShape = new long[] {nOut};
params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
}
}

@@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer {
if (hasBias) {
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
- SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b);
+ SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
return activation.asSameDiff("out", sameDiff, biasAddedResult);
} else {
return activation.asSameDiff("out", sameDiff, result);
@@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer {
val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut};
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
if (hasBias) {
- val biasShape = new long[] {1, nOut};
+ val biasShape = new long[] {nOut};
params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
}
}

@@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer {
if (hasBias) {
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
- SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b);
+ SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true);
return activation.asSameDiff("out", sameDiff, biasAddedResult);
} else {
return activation.asSameDiff("out", sameDiff, permutedResult);
@@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex {
}

try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
- return Nd4j.hstack(in);
+ return Nd4j.concat(1, in);
}
}
@@ -134,6 +134,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
Gradient g = new DefaultGradient();

INDArray[] dLdIns;
+ boolean[] noClose = new boolean[getNumInputArrays()];
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
if(sameDiff == null){
doInit();

@@ -167,20 +168,21 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
+ List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

- List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
- for(String s : inputNames){
-     required.add(sameDiff.getVariable(s).gradient().getVarName());
- }
- sameDiff.execBackwards(phMap, required);
+ required.addAll(paramTable.keySet());
+ required.addAll(inputNames);

+ Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
for(String s : paramTable.keySet() ){
- INDArray sdGrad = sameDiff.grad(s).getArr();
+ INDArray sdGrad = gradsMap.get(s);
INDArray dl4jGrad = gradTable.get(s);
dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
+ sdGrad.close(); //TODO optimize this
g.gradientForVariable().put(s, dl4jGrad);
}

@@ -195,13 +197,18 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIns[j] = epsilon;
+ noClose[j] = true;
}
}
}

+ //TODO optimize
+ for( int i=0; i<dLdIns.length; i++ ){
+     INDArray before = dLdIns[i];
+     dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
+     if(!noClose[i]){
+         before.close();
+     }
+ }

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@@ -110,7 +110,13 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

- return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+ INDArray ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+ if(!result.isAttached() && result.closeable()) {
+     //May be attached in rare edge case - for identity, or if gradients are passed through from output to input
+     // unchanged, as in identity, add scalar, etc
+     result.close();
+ }
+ return ret;
}
}

@@ -122,6 +128,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
Gradient g = new DefaultGradient();

INDArray dLdIn;
+ boolean noCloseEps = false;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
if(sameDiff == null){
doInit();

@@ -151,26 +158,25 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
}

List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
- requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
- for(String s : paramTable.keySet()){
-     requiredGrads.add(sameDiff.grad(s).getVarName());
- }
+ requiredGrads.add(INPUT_KEY);
+ requiredGrads.addAll(paramTable.keySet());

- sameDiff.execBackwards(phMap, requiredGrads);
+ Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
for(String s : paramTable.keySet() ){
- INDArray sdGrad = sameDiff.grad(s).getArr();
+ INDArray sdGrad = m.get(s);
INDArray dl4jGrad = gradTable.get(s);
dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
g.gradientForVariable().put(s, dl4jGrad);
+ sdGrad.close();
}

- SDVariable v = sameDiff.grad(INPUT_KEY);
- dLdIn = v.getArr();
+ dLdIn = m.get(INPUT_KEY);

- if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
+ if(dLdIn == null && fn.getGradPlaceholderName().equals(INPUT_KEY)){
//Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIn = epsilon;
+ noCloseEps = true;
}
}

@@ -178,7 +184,12 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

- return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
+ Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
+ if(!noCloseEps && !dLdIn.isAttached() && dLdIn.closeable()) {
+     //Edge case: identity etc - might just pass gradient array through unchanged
+     dLdIn.close();
+ }
+ return ret;
}

/**Returns the parameters of the neural network as a flattened row vector
@ -106,6 +106,12 @@ public struct FlatNode : IFlatbufferObject
|
|||
#endif
|
||||
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
|
||||
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||
public string ControlDeps(int j) { int o = __p.__offset(42); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepsLength { get { int o = __p.__offset(42); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public string VarControlDeps(int j) { int o = __p.__offset(44); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
|
||||
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
|
||||
int id = 0,
|
||||
|
@ -126,9 +132,15 @@ public struct FlatNode : IFlatbufferObject
|
|||
VectorOffset outputNamesOffset = default(VectorOffset),
|
||||
StringOffset opNameOffset = default(StringOffset),
|
||||
VectorOffset outputTypesOffset = default(VectorOffset),
|
||||
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>)) {
|
||||
builder.StartObject(19);
|
||||
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
|
||||
VectorOffset controlDepsOffset = default(VectorOffset),
|
||||
VectorOffset varControlDepsOffset = default(VectorOffset),
|
||||
VectorOffset controlDepForOffset = default(VectorOffset)) {
|
||||
builder.StartObject(22);
|
||||
FlatNode.AddOpNum(builder, opNum);
|
||||
FlatNode.AddControlDepFor(builder, controlDepForOffset);
|
||||
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
|
||||
FlatNode.AddControlDeps(builder, controlDepsOffset);
|
||||
FlatNode.AddScalar(builder, scalarOffset);
|
||||
FlatNode.AddOutputTypes(builder, outputTypesOffset);
|
||||
FlatNode.AddOpName(builder, opNameOffset);
|
||||
|
@ -150,7 +162,7 @@ public struct FlatNode : IFlatbufferObject
|
|||
return FlatNode.EndFlatNode(builder);
|
||||
}
|
||||
|
||||
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(19); }
|
||||
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
|
||||
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
|
||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
|
||||
|
@ -200,6 +212,18 @@ public struct FlatNode : IFlatbufferObject
|
|||
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
|
||||
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
|
||||
public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(19, controlDepsOffset.Value, 0); }
|
||||
public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddVarControlDeps(FlatBufferBuilder builder, VectorOffset varControlDepsOffset) { builder.AddOffset(20, varControlDepsOffset.Value, 0); }
|
||||
public static VectorOffset CreateVarControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateVarControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddControlDepFor(FlatBufferBuilder builder, VectorOffset controlDepForOffset) { builder.AddOffset(21, controlDepForOffset.Value, 0); }
|
||||
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
|
||||
int o = builder.EndObject();
|
||||
return new Offset<FlatNode>(o);
|
||||
|
|
|
@ -66,6 +66,12 @@ public final class FlatNode extends Table {
|
|||
public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
|
||||
public FlatArray scalar() { return scalar(new FlatArray()); }
|
||||
public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
|
||||
public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
|
||||
|
||||
public static int createFlatNode(FlatBufferBuilder builder,
|
||||
int id,
|
||||
|
@ -86,9 +92,15 @@ public final class FlatNode extends Table {
|
|||
int outputNamesOffset,
|
||||
int opNameOffset,
|
||||
int outputTypesOffset,
|
||||
int scalarOffset) {
|
||||
builder.startObject(19);
|
||||
int scalarOffset,
|
||||
int controlDepsOffset,
|
||||
int varControlDepsOffset,
|
||||
int controlDepForOffset) {
|
||||
builder.startObject(22);
|
||||
FlatNode.addOpNum(builder, opNum);
|
||||
FlatNode.addControlDepFor(builder, controlDepForOffset);
|
||||
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
|
||||
FlatNode.addControlDeps(builder, controlDepsOffset);
|
||||
FlatNode.addScalar(builder, scalarOffset);
|
||||
FlatNode.addOutputTypes(builder, outputTypesOffset);
|
||||
FlatNode.addOpName(builder, opNameOffset);
|
||||
|
@ -110,7 +122,7 @@ public final class FlatNode extends Table {
|
|||
return FlatNode.endFlatNode(builder);
|
||||
}
|
||||
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); }
|
||||
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
|
||||
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
|
||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
|
||||
|
@ -150,6 +162,15 @@ public final class FlatNode extends Table {
|
|||
public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
|
||||
public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
|
||||
public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
|
||||
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
|
||||
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
|
||||
public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
|
||||
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static int endFlatNode(FlatBufferBuilder builder) {
|
||||
int o = builder.endObject();
|
||||
return o;
|
||||
|
|
|
@ -294,7 +294,52 @@ class FlatNode(object):
|
|||
return obj
|
||||
return None
|
||||
|
||||
def FlatNodeStart(builder): builder.StartObject(19)
|
||||
# FlatNode
|
||||
def ControlDeps(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatNode
|
||||
def ControlDepsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# FlatNode
|
||||
def VarControlDeps(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatNode
|
||||
def VarControlDepsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# FlatNode
|
||||
def ControlDepFor(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatNode
|
||||
def ControlDepForLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
def FlatNodeStart(builder): builder.StartObject(22)
|
||||
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
|
||||
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
|
||||
|
@ -324,4 +369,10 @@ def FlatNodeAddOpName(builder, opName): builder.PrependUOffsetTRelativeSlot(16,
|
|||
def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0)
|
||||
def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
|
||||
def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0)
|
||||
def FlatNodeAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(19, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
|
||||
def FlatNodeStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTRelativeSlot(20, flatbuffers.number_types.UOffsetTFlags.py_type(varControlDeps), 0)
|
||||
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
|
||||
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatNodeEnd(builder): return builder.EndObject()
|
||||
|
|
|
@ -37,6 +37,12 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
|
||||
public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } }
|
||||
public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
|
||||
public string ControlDeps(int j) { int o = __p.__offset(18); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepsLength { get { int o = __p.__offset(18); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public string ControlDepForOp(int j) { int o = __p.__offset(20); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepForOpLength { get { int o = __p.__offset(20); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
public string ControlDepsForVar(int j) { int o = __p.__offset(22); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
|
||||
public int ControlDepsForVarLength { get { int o = __p.__offset(22); return o != 0 ? __p.__vector_len(o) : 0; } }
|
||||
|
||||
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
|
||||
Offset<IntPair> idOffset = default(Offset<IntPair>),
|
||||
|
@ -45,8 +51,14 @@ public struct FlatVariable : IFlatbufferObject
|
|||
VectorOffset shapeOffset = default(VectorOffset),
|
||||
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
|
||||
int device = 0,
|
||||
VarType variabletype = VarType.VARIABLE) {
|
||||
builder.StartObject(7);
|
||||
VarType variabletype = VarType.VARIABLE,
|
||||
VectorOffset controlDepsOffset = default(VectorOffset),
|
||||
VectorOffset controlDepForOpOffset = default(VectorOffset),
|
||||
VectorOffset controlDepsForVarOffset = default(VectorOffset)) {
|
||||
builder.StartObject(10);
|
||||
FlatVariable.AddControlDepsForVar(builder, controlDepsForVarOffset);
|
||||
FlatVariable.AddControlDepForOp(builder, controlDepForOpOffset);
|
||||
FlatVariable.AddControlDeps(builder, controlDepsOffset);
|
||||
FlatVariable.AddDevice(builder, device);
|
||||
FlatVariable.AddNdarray(builder, ndarrayOffset);
|
||||
FlatVariable.AddShape(builder, shapeOffset);
|
||||
|
@ -57,7 +69,7 @@ public struct FlatVariable : IFlatbufferObject
|
|||
return FlatVariable.EndFlatVariable(builder);
|
||||
}
|
||||
|
||||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); }
|
||||
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(10); }
|
||||
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
|
||||
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
|
||||
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
|
||||
|
@ -68,6 +80,18 @@ public struct FlatVariable : IFlatbufferObject
|
|||
public static void AddNdarray(FlatBufferBuilder builder, Offset<FlatArray> ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); }
|
||||
public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); }
|
||||
public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); }
|
||||
public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(7, controlDepsOffset.Value, 0); }
|
||||
public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddControlDepForOp(FlatBufferBuilder builder, VectorOffset controlDepForOpOffset) { builder.AddOffset(8, controlDepForOpOffset.Value, 0); }
|
||||
public static VectorOffset CreateControlDepForOpVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepForOpVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static void AddControlDepsForVar(FlatBufferBuilder builder, VectorOffset controlDepsForVarOffset) { builder.AddOffset(9, controlDepsForVarOffset.Value, 0); }
|
||||
public static VectorOffset CreateControlDepsForVarVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
|
||||
public static VectorOffset CreateControlDepsForVarVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
|
||||
public static void StartControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
|
||||
public static Offset<FlatVariable> EndFlatVariable(FlatBufferBuilder builder) {
|
||||
int o = builder.EndObject();
|
||||
return new Offset<FlatVariable>(o);
|
||||
|
|
|
@ -28,6 +28,12 @@ public final class FlatVariable extends Table {
|
|||
public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
|
||||
public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
|
||||
public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
|
||||
public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
|
||||
public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
|
||||
public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }
|
||||
|
||||
public static int createFlatVariable(FlatBufferBuilder builder,
|
||||
int idOffset,
|
||||
|
@ -36,8 +42,14 @@ public final class FlatVariable extends Table {
|
|||
int shapeOffset,
|
||||
int ndarrayOffset,
|
||||
int device,
|
||||
byte variabletype) {
|
||||
builder.startObject(7);
|
||||
byte variabletype,
|
||||
int controlDepsOffset,
|
||||
int controlDepForOpOffset,
|
||||
int controlDepsForVarOffset) {
|
||||
builder.startObject(10);
|
||||
FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
|
||||
FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
|
||||
FlatVariable.addControlDeps(builder, controlDepsOffset);
|
||||
FlatVariable.addDevice(builder, device);
|
||||
FlatVariable.addNdarray(builder, ndarrayOffset);
|
||||
FlatVariable.addShape(builder, shapeOffset);
|
||||
|
@ -48,7 +60,7 @@ public final class FlatVariable extends Table {
|
|||
return FlatVariable.endFlatVariable(builder);
|
||||
}
|
||||
|
||||
public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); }
|
||||
public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
|
||||
public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
|
||||
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
|
||||
public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }
|
||||
|
@ -58,6 +70,15 @@ public final class FlatVariable extends Table {
|
|||
public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
|
||||
public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
|
||||
public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
|
||||
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
|
||||
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
|
||||
public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
|
||||
public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
|
||||
public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
|
||||
public static int endFlatVariable(FlatBufferBuilder builder) {
|
||||
int o = builder.endObject();
|
||||
return o;
|
||||
|
|
|
@ -90,7 +90,52 @@ class FlatVariable(object):
|
|||
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
|
||||
return 0
|
||||
|
||||
def FlatVariableStart(builder): builder.StartObject(7)
|
||||
# FlatVariable
|
||||
def ControlDeps(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatVariable
|
||||
def ControlDepsLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# FlatVariable
|
||||
def ControlDepForOp(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatVariable
|
||||
def ControlDepForOpLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# FlatVariable
|
||||
def ControlDepsForVar(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
||||
return ""
|
||||
|
||||
# FlatVariable
|
||||
def ControlDepsForVarLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
def FlatVariableStart(builder): builder.StartObject(10)
|
||||
def FlatVariableAddId(builder, id): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(id), 0)
|
||||
def FlatVariableAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
|
||||
def FlatVariableAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0)
|
||||
|
@ -99,4 +144,10 @@ def FlatVariableStartShapeVector(builder, numElems): return builder.StartVector(
|
|||
def FlatVariableAddNdarray(builder, ndarray): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(ndarray), 0)
|
||||
def FlatVariableAddDevice(builder, device): builder.PrependInt32Slot(5, device, 0)
|
||||
def FlatVariableAddVariabletype(builder, variabletype): builder.PrependInt8Slot(6, variabletype, 0)
|
||||
def FlatVariableAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
|
||||
def FlatVariableStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatVariableAddControlDepForOp(builder, controlDepForOp): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepForOp), 0)
|
||||
def FlatVariableStartControlDepForOpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatVariableAddControlDepsForVar(builder, controlDepsForVar): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepsForVar), 0)
|
||||
def FlatVariableStartControlDepsForVarVector(builder, numElems): return builder.StartVector(4, numElems, 4)
|
||||
def FlatVariableEnd(builder): return builder.EndObject()
|
||||
|
|
|
@ -35,7 +35,10 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VT_OUTPUTNAMES = 34,
|
||||
VT_OPNAME = 36,
|
||||
VT_OUTPUTTYPES = 38,
|
||||
VT_SCALAR = 40
|
||||
VT_SCALAR = 40,
|
||||
VT_CONTROLDEPS = 42,
|
||||
VT_VARCONTROLDEPS = 44,
|
||||
VT_CONTROLDEPFOR = 46
|
||||
};
|
||||
int32_t id() const {
|
||||
return GetField<int32_t>(VT_ID, 0);
|
||||
|
@ -94,6 +97,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
const FlatArray *scalar() const {
|
||||
return GetPointer<const FlatArray *>(VT_SCALAR);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_VARCONTROLDEPS);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<int32_t>(verifier, VT_ID) &&
|
||||
|
@ -132,6 +144,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
verifier.VerifyVector(outputTypes()) &&
|
||||
VerifyOffset(verifier, VT_SCALAR) &&
|
||||
verifier.VerifyTable(scalar()) &&
|
||||
VerifyOffset(verifier, VT_CONTROLDEPS) &&
|
||||
verifier.VerifyVector(controlDeps()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDeps()) &&
|
||||
VerifyOffset(verifier, VT_VARCONTROLDEPS) &&
|
||||
verifier.VerifyVector(varControlDeps()) &&
|
||||
verifier.VerifyVectorOfStrings(varControlDeps()) &&
|
||||
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
|
||||
verifier.VerifyVector(controlDepFor()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDepFor()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
};
|
||||
|
@ -196,6 +217,15 @@ struct FlatNodeBuilder {
|
|||
void add_scalar(flatbuffers::Offset<FlatArray> scalar) {
|
||||
fbb_.AddOffset(FlatNode::VT_SCALAR, scalar);
|
||||
}
|
||||
void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
|
||||
fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps);
|
||||
}
|
||||
void add_varControlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps) {
|
||||
fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps);
|
||||
}
|
||||
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
|
||||
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
|
||||
}
|
||||
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
|
@ -228,9 +258,15 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
|
|||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputNames = 0,
|
||||
flatbuffers::Offset<flatbuffers::String> opName = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<int8_t>> outputTypes = 0,
|
||||
flatbuffers::Offset<FlatArray> scalar = 0) {
|
||||
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
|
||||
FlatNodeBuilder builder_(_fbb);
|
||||
builder_.add_opNum(opNum);
|
||||
builder_.add_controlDepFor(controlDepFor);
|
||||
builder_.add_varControlDeps(varControlDeps);
|
||||
builder_.add_controlDeps(controlDeps);
|
||||
builder_.add_scalar(scalar);
|
||||
builder_.add_outputTypes(outputTypes);
|
||||
builder_.add_opName(opName);
|
||||
|
@ -272,7 +308,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
|||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputNames = nullptr,
|
||||
const char *opName = nullptr,
|
||||
const std::vector<int8_t> *outputTypes = nullptr,
|
||||
flatbuffers::Offset<FlatArray> scalar = 0) {
|
||||
flatbuffers::Offset<FlatArray> scalar = 0,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
|
||||
return nd4j::graph::CreateFlatNode(
|
||||
_fbb,
|
||||
id,
|
||||
|
@ -293,7 +332,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
|
|||
outputNames ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputNames) : 0,
|
||||
opName ? _fbb.CreateString(opName) : 0,
|
||||
outputTypes ? _fbb.CreateVector<int8_t>(*outputTypes) : 0,
|
||||
scalar);
|
||||
scalar,
|
||||
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
|
||||
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
|
||||
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
|
||||
}
|
||||
|
||||
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {
|
||||
|
|
|
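On the read side, the new FlatNode fields can be walked like any other flatbuffers string vector. A minimal sketch, assuming the generated nd4j.graph.FlatNode Java class exposes the usual length-plus-indexed-getter accessors shown in the C++/JS code above; the class and method names around it are illustrative.

// Minimal sketch: reading the control-dependency vectors back from a serialized FlatNode.
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.nd4j.graph.FlatNode;

public class FlatNodeControlDepsSketch {
    public static List<String> readControlDeps(ByteBuffer serialized) {
        FlatNode node = FlatNode.getRootAsFlatNode(serialized);
        List<String> out = new ArrayList<>();
        for (int i = 0; i < node.controlDepsLength(); i++) {
            out.add(node.controlDeps(i));   // op-level control dependencies
        }
        // varControlDeps() and controlDepFor() would be read the same way
        return out;
    }
}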
@ -344,11 +344,65 @@ nd4j.graph.FlatNode.prototype.scalar = function(obj) {
|
|||
return offset ? (obj || new nd4j.graph.FlatArray).__init(this.bb.__indirect(this.bb_pos + offset), this.bb) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.controlDeps = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 42);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.controlDepsLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 42);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.varControlDeps = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 44);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.varControlDepsLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 44);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.controlDepFor = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 46);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 46);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
*/
|
||||
nd4j.graph.FlatNode.startFlatNode = function(builder) {
|
||||
builder.startObject(19);
|
||||
builder.startObject(22);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -713,6 +767,93 @@ nd4j.graph.FlatNode.addScalar = function(builder, scalarOffset) {
|
|||
builder.addFieldOffset(18, scalarOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} controlDepsOffset
|
||||
*/
|
||||
nd4j.graph.FlatNode.addControlDeps = function(builder, controlDepsOffset) {
|
||||
builder.addFieldOffset(19, controlDepsOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createControlDepsVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatNode.startControlDepsVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} varControlDepsOffset
|
||||
*/
|
||||
nd4j.graph.FlatNode.addVarControlDeps = function(builder, varControlDepsOffset) {
|
||||
builder.addFieldOffset(20, varControlDepsOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createVarControlDepsVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatNode.startVarControlDepsVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} controlDepForOffset
|
||||
*/
|
||||
nd4j.graph.FlatNode.addControlDepFor = function(builder, controlDepForOffset) {
|
||||
builder.addFieldOffset(21, controlDepForOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatNode.createControlDepForVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @returns {flatbuffers.Offset}
|
||||
|
|
|
@ -57,7 +57,10 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VT_SHAPE = 10,
|
||||
VT_NDARRAY = 12,
|
||||
VT_DEVICE = 14,
|
||||
VT_VARIABLETYPE = 16
|
||||
VT_VARIABLETYPE = 16,
|
||||
VT_CONTROLDEPS = 18,
|
||||
VT_CONTROLDEPFOROP = 20,
|
||||
VT_CONTROLDEPSFORVAR = 22
|
||||
};
|
||||
const IntPair *id() const {
|
||||
return GetPointer<const IntPair *>(VT_ID);
|
||||
|
@ -80,6 +83,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
VarType variabletype() const {
|
||||
return static_cast<VarType>(GetField<int8_t>(VT_VARIABLETYPE, 0));
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOROP);
|
||||
}
|
||||
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar() const {
|
||||
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPSFORVAR);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyOffset(verifier, VT_ID) &&
|
||||
|
@ -93,6 +105,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||
verifier.VerifyTable(ndarray()) &&
|
||||
VerifyField<int32_t>(verifier, VT_DEVICE) &&
|
||||
VerifyField<int8_t>(verifier, VT_VARIABLETYPE) &&
|
||||
VerifyOffset(verifier, VT_CONTROLDEPS) &&
|
||||
verifier.VerifyVector(controlDeps()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDeps()) &&
|
||||
VerifyOffset(verifier, VT_CONTROLDEPFOROP) &&
|
||||
verifier.VerifyVector(controlDepForOp()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDepForOp()) &&
|
||||
VerifyOffset(verifier, VT_CONTROLDEPSFORVAR) &&
|
||||
verifier.VerifyVector(controlDepsForVar()) &&
|
||||
verifier.VerifyVectorOfStrings(controlDepsForVar()) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
};
|
||||
|
@ -121,6 +142,15 @@ struct FlatVariableBuilder {
|
|||
void add_variabletype(VarType variabletype) {
|
||||
fbb_.AddElement<int8_t>(FlatVariable::VT_VARIABLETYPE, static_cast<int8_t>(variabletype), 0);
|
||||
}
|
||||
void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
|
||||
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps);
|
||||
}
|
||||
void add_controlDepForOp(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp) {
|
||||
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp);
|
||||
}
|
||||
void add_controlDepsForVar(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar) {
|
||||
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar);
|
||||
}
|
||||
explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
|
@ -141,8 +171,14 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
|
|||
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
VarType variabletype = VarType_VARIABLE) {
|
||||
VarType variabletype = VarType_VARIABLE,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp = 0,
|
||||
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar = 0) {
|
||||
FlatVariableBuilder builder_(_fbb);
|
||||
builder_.add_controlDepsForVar(controlDepsForVar);
|
||||
builder_.add_controlDepForOp(controlDepForOp);
|
||||
builder_.add_controlDeps(controlDeps);
|
||||
builder_.add_device(device);
|
||||
builder_.add_ndarray(ndarray);
|
||||
builder_.add_shape(shape);
|
||||
|
@ -161,7 +197,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
|||
const std::vector<int64_t> *shape = nullptr,
|
||||
flatbuffers::Offset<FlatArray> ndarray = 0,
|
||||
int32_t device = 0,
|
||||
VarType variabletype = VarType_VARIABLE) {
|
||||
VarType variabletype = VarType_VARIABLE,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp = nullptr,
|
||||
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar = nullptr) {
|
||||
return nd4j::graph::CreateFlatVariable(
|
||||
_fbb,
|
||||
id,
|
||||
|
@ -170,7 +209,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
|
|||
shape ? _fbb.CreateVector<int64_t>(*shape) : 0,
|
||||
ndarray,
|
||||
device,
|
||||
variabletype);
|
||||
variabletype,
|
||||
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
|
||||
controlDepForOp ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepForOp) : 0,
|
||||
controlDepsForVar ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepsForVar) : 0);
|
||||
}
|
||||
|
||||
inline const nd4j::graph::FlatVariable *GetFlatVariable(const void *buf) {
|
||||
|
|
|
@ -125,11 +125,65 @@ nd4j.graph.FlatVariable.prototype.variabletype = function() {
|
|||
return offset ? /** @type {nd4j.graph.VarType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.VarType.VARIABLE;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDeps = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 18);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDepsLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 18);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDepForOp = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 20);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDepForOpLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 20);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {number} index
|
||||
* @param {flatbuffers.Encoding=} optionalEncoding
|
||||
* @returns {string|Uint8Array}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDepsForVar = function(index, optionalEncoding) {
|
||||
var offset = this.bb.__offset(this.bb_pos, 22);
|
||||
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns {number}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.prototype.controlDepsForVarLength = function() {
|
||||
var offset = this.bb.__offset(this.bb_pos, 22);
|
||||
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
*/
|
||||
nd4j.graph.FlatVariable.startFlatVariable = function(builder) {
|
||||
builder.startObject(7);
|
||||
builder.startObject(10);
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -209,6 +263,93 @@ nd4j.graph.FlatVariable.addVariabletype = function(builder, variabletype) {
|
|||
builder.addFieldInt8(6, variabletype, nd4j.graph.VarType.VARIABLE);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} controlDepsOffset
|
||||
*/
|
||||
nd4j.graph.FlatVariable.addControlDeps = function(builder, controlDepsOffset) {
|
||||
builder.addFieldOffset(7, controlDepsOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.createControlDepsVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatVariable.startControlDepsVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} controlDepForOpOffset
|
||||
*/
|
||||
nd4j.graph.FlatVariable.addControlDepForOp = function(builder, controlDepForOpOffset) {
|
||||
builder.addFieldOffset(8, controlDepForOpOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.createControlDepForOpVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatVariable.startControlDepForOpVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {flatbuffers.Offset} controlDepsForVarOffset
|
||||
*/
|
||||
nd4j.graph.FlatVariable.addControlDepsForVar = function(builder, controlDepsForVarOffset) {
|
||||
builder.addFieldOffset(9, controlDepsForVarOffset, 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {Array.<flatbuffers.Offset>} data
|
||||
* @returns {flatbuffers.Offset}
|
||||
*/
|
||||
nd4j.graph.FlatVariable.createControlDepsForVarVector = function(builder, data) {
|
||||
builder.startVector(4, data.length, 4);
|
||||
for (var i = data.length - 1; i >= 0; i--) {
|
||||
builder.addOffset(data[i]);
|
||||
}
|
||||
return builder.endVector();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @param {number} numElems
|
||||
*/
|
||||
nd4j.graph.FlatVariable.startControlDepsForVarVector = function(builder, numElems) {
|
||||
builder.startVector(4, numElems, 4);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {flatbuffers.Builder} builder
|
||||
* @returns {flatbuffers.Offset}
|
||||
|
|
|
@ -52,6 +52,12 @@ table FlatNode {

    //Scalar value - used for scalar ops. Should be single value only.
    scalar:FlatArray;

    //Control dependencies
    controlDeps:[string];
    varControlDeps:[string];
    controlDepFor:[string];

}

root_type FlatNode;

@ -37,6 +37,10 @@ table FlatVariable {

    device:int; // default is -1, which means _auto_
    variabletype:VarType;

    controlDeps:[string];
    controlDepForOp:[string];
    controlDepsForVar:[string];
}

root_type FlatVariable;
|
|
@ -659,7 +659,8 @@ public abstract class DifferentialFunction {
        if(sameDiff == null)
            this.ownName = UUID.randomUUID().toString();
        else {
            this.ownName = sameDiff.getOpName(opName());
            String n = sameDiff.getOpName(opName());
            this.ownName = n;
        }

        if(sameDiff != null)

@ -696,30 +697,11 @@ public abstract class DifferentialFunction {
    }

    @JsonIgnore
    private INDArray getX() {
        INDArray ret = sameDiff.getArrForVarName(args()[0].getVarName());
        return ret;
    public INDArray getInputArgument(int index){
        //Subclasses should implement this
        throw new UnsupportedOperationException("Not implemented");
    }

    @JsonIgnore
    private INDArray getY() {
        if(args().length > 1) {
            INDArray ret = sameDiff.getArrForVarName(args()[1].getVarName());
            return ret;
        }
        return null;
    }

    @JsonIgnore
    private INDArray getZ() {
        if(isInPlace())
            return getX();
        SDVariable opId = outputVariables()[0];
        INDArray ret = opId.getArr();
        return ret;
    }


    /**

@ -860,4 +842,8 @@ public abstract class DifferentialFunction {

    public int getNumOutputs(){return -1;}

    /**
     * Clear the input and output INDArrays, if any are set
     */
    public abstract void clearArrays();
}
|
||||
|
|
|
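The new hooks above (getInputArgument(int) and the abstract clearArrays()) are what each op must provide so the memory manager can release arrays after execution. A hedged sketch of how a hypothetical op subclass could back them with simple lists; the field and class names are illustrative, not the actual DynamicCustomOp implementation.

// Hedged sketch only: a hypothetical op class backing getInputArgument()/clearArrays() with lists.
import java.util.ArrayList;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ExampleArrayHoldingOp /* would extend DifferentialFunction in practice */ {

    private final List<INDArray> inputArguments = new ArrayList<>();
    private final List<INDArray> outputArguments = new ArrayList<>();

    public INDArray getInputArgument(int index) {
        // Return null rather than throwing when the array has not been set yet
        return inputArguments.size() > index ? inputArguments.get(index) : null;
    }

    public void clearArrays() {
        // Drop references so the SameDiff memory manager can close/reuse the buffers
        inputArguments.clear();
        outputArguments.clear();
    }
}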
@ -982,8 +982,8 @@ public class DifferentialFunctionFactory {
        return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable();
    }

    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
        return new BiasAdd(sameDiff(), input, bias).outputVariable();
    public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
        return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
    }

    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
|
||||
|
|
|
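A short usage sketch of the extended bias-add signature at the SameDiff level. The shapes and names are illustrative, and it assumes the public SDNN.biasAdd wrapper gains the same nchw flag as the factory method above; the wrapper signature is not confirmed by this diff.

// Illustrative sketch, assuming the public wrapper mirrors the new nchw flag above.
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class BiasAddNchwSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // NCHW activations: [minibatch, channels, height, width]; one bias value per channel
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 16, 28, 28);
        SDVariable bias = sd.var("bias", DataType.FLOAT, 16);
        // The boolean selects the data format: true = NCHW (channels first), false = NHWC
        SDVariable out = sd.nn().biasAdd(in, bias, true);   // assumed wrapper signature
    }
}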
@ -24,6 +24,7 @@ import lombok.Getter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

@ -319,6 +320,7 @@ public class History {
     * Gets the training evaluations ran during the last epoch
     */
    public EvaluationRecord finalTrainingEvaluations(){
        Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty");
        return trainingHistory.get(trainingHistory.size() - 1);
    }

@ -326,6 +328,7 @@ public class History {
     * Gets the validation evaluations ran during the last epoch
     */
    public EvaluationRecord finalValidationEvaluations(){
        Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty");
        return validationHistory.get(validationHistory.size() - 1);
    }
|
||||
|
||||
|
|
|
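These accessors give listeners and callers a shortcut to the last epoch's records. A small hedged sketch of reading them from the History returned by fit; it assumes evaluations were actually configured during training, and the EvaluationRecord import path is assumed to sit next to History in the listeners.records package.

// Sketch: pulling the last epoch's evaluations out of the History returned by fit().
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.History;

public class HistorySketch {
    public static void summarize(History history) {
        // Final (last-epoch) training evaluations; throws if training evaluation was never configured
        EvaluationRecord lastTrain = history.finalTrainingEvaluations();

        // Validation counterpart, guarded the same way in the change above
        EvaluationRecord lastValidation = history.finalValidationEvaluations();

        System.out.println("Training records: " + lastTrain);
        System.out.println("Validation records: " + lastValidation);
    }
}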
@ -16,34 +16,23 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import java.util.Objects;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.weightinit.WeightInitScheme;
|
||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -167,6 +156,10 @@ public class SDVariable implements Serializable {
        if(sameDiff.arrayAlreadyExistsForVarName(getVarName()))
            return sameDiff.getArrForVarName(getVarName());

        if(variableType == VariableType.ARRAY){
            throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
        }

        //initialize value if it's actually a scalar constant (zero or 1 typically...)
        if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){
            INDArray arr = weightInitScheme.create(dataType, shape);
|
||||
|
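In practice this means op outputs (ARRAY-type variables) no longer expose a cached array through getArr(); their values have to be requested from an execution. A hedged sketch of the intended pattern, with illustrative variable names and the List-based output() overload used elsewhere in this change:

// Sketch: ARRAY-type variables (op outputs) are computed via output()/exec rather than getArr().
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArrayTypeVariableSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
        SDVariable mmul = in.mmul(w);            // ARRAY type: an op output

        Map<String, INDArray> placeholders =
                Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));

        // mmul.getArr() would now throw UnsupportedOperationException; request the value instead:
        Map<String, INDArray> out =
                sd.output(placeholders, Collections.singletonList(mmul.getVarName()));
    }
}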
@ -211,8 +204,8 @@ public class SDVariable implements Serializable {
     * created automatically when training is performed.
     */
    public SDVariable getGradient() {
        Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating" +
                " point variables have gradients", getVarName(), dataType());
        Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" +
                " point variables have gradients", dataType(), getVarName());
        return sameDiff.getGradForVariable(getVarName());
    }

@ -230,7 +223,7 @@ public class SDVariable implements Serializable {
    }

        long[] initialShape = sameDiff.getShapeForVarName(getVarName());
        if(initialShape == null) {
        if(initialShape == null && variableType != VariableType.ARRAY) {
            val arr = getArr();
            if(arr != null)
                return arr.shape();

@ -254,7 +247,7 @@ public class SDVariable implements Serializable {
    public DataType dataType() {
        if(this.dataType == null){
            //Try to infer datatype instead of returning null
            if(getArr() != null){
            if(variableType != VariableType.ARRAY && getArr() != null){
                this.dataType = getArr().dataType();
            }
        }
|
||||
|
@ -1518,26 +1511,59 @@ public class SDVariable implements Serializable {

    /**
     * Add a control dependency for this variable on the specified variable.<br>
     * Control depnedencies can be used to enforce the execution order.
     * Control dependencies can be used to enforce the execution order.
     * For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even
     * if Y wouldn't normally depend on the result/values of X.
     *
     * @param controlDependency Control dependency to add for this variable
     */
    public void addControlDependency(SDVariable controlDependency){
        String cdN = controlDependency.getVarName();
        String n = this.getVarName();
        Variable v = sameDiff.getVariables().get(n);
        if(v.getControlDeps() == null)
            v.setControlDeps(new ArrayList<String>());
        if(!v.getControlDeps().contains(cdN))
            v.getControlDeps().add(cdN);
        Variable vThis = sameDiff.getVariables().get(getVarName());
        Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName());

        Variable v2 = sameDiff.getVariables().get(cdN);
        if(v2.getControlDepsForVar() == null)
            v2.setControlDepsForVar(new ArrayList<String>());
        if(!v2.getControlDepsForVar().contains(n))
            v2.getControlDepsForVar().add(n);
        //If possible: add control dependency on ops
        if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){
            //Op -> Op case
            SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());
            SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp());

            if(oThis.getControlDeps() == null)
                oThis.setControlDeps(new ArrayList<String>());
            if(!oThis.getControlDeps().contains(oCD.getName()))
                oThis.getControlDeps().add(oCD.getName());

            if(oCD.getControlDepFor() == null)
                oCD.setControlDepFor(new ArrayList<String>());
            if(!oCD.getControlDepFor().contains(oThis.getName()))
                oCD.getControlDepFor().add(oThis.getName());
        } else {
            if(vThis.getOutputOfOp() != null){
                //const/ph -> op case
                SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());

                if(oThis.getVarControlDeps() == null)
                    oThis.setVarControlDeps(new ArrayList<String>());

                if(!oThis.getVarControlDeps().contains(vCD.getName()))
                    oThis.getVarControlDeps().add(vCD.getName());

                if(vCD.getControlDepsForOp() == null)
                    vCD.setControlDepsForOp(new ArrayList<String>());
                if(!vCD.getControlDepsForOp().contains(oThis.getName()))
                    vCD.getControlDepsForOp().add(oThis.getName());
            } else {
                //const/ph -> const/ph case
                if(vThis.getControlDeps() == null)
                    vThis.setControlDeps(new ArrayList<String>());
                if(!vThis.getControlDeps().contains(vCD.getName()))
                    vThis.getControlDeps().add(vCD.getName());

                if(vCD.getControlDepsForVar() == null)
                    vCD.setControlDepsForVar(new ArrayList<String>());
                if(!vCD.getControlDepsForVar().contains(vThis.getName()))
                    vCD.getControlDepsForVar().add(vThis.getName());
            }
        }
    }
|
||||
|
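A brief usage sketch of the reworked API: the dependency is now recorded on both sides (op or variable, as the three cases above distinguish), so at the call site nothing changes beyond calling the method. Variable names below are illustrative.

// Sketch: force "b" to execute only after "a", even though b does not consume a's value.
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class ControlDependencySketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.placeHolder("x", DataType.FLOAT, -1, 3);
        SDVariable a = x.add("a", 1.0);          // some op output
        SDVariable b = x.mul("b", 2.0);          // unrelated op output

        // b will only execute after a has executed (the op -> op case above)
        b.addControlDependency(a);
    }
}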
||||
/**
|
||||
|
|
|
@ -16,58 +16,16 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
import java.util.Stack;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.listeners.*;
|
||||
import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
|
@ -75,34 +33,14 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
|
|||
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
||||
import org.nd4j.autodiff.samediff.config.FitConfig;
|
||||
import org.nd4j.autodiff.samediff.config.OutputConfig;
|
||||
import org.nd4j.autodiff.samediff.internal.AbstractSession;
|
||||
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
|
||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
|
||||
import org.nd4j.autodiff.samediff.ops.SDBitwise;
|
||||
import org.nd4j.autodiff.samediff.ops.SDCNN;
|
||||
import org.nd4j.autodiff.samediff.ops.SDImage;
|
||||
import org.nd4j.autodiff.samediff.ops.SDLoss;
|
||||
import org.nd4j.autodiff.samediff.ops.SDMath;
|
||||
import org.nd4j.autodiff.samediff.ops.SDNN;
|
||||
import org.nd4j.autodiff.samediff.ops.SDRNN;
|
||||
import org.nd4j.autodiff.samediff.ops.SDRandom;
|
||||
import org.nd4j.autodiff.samediff.internal.*;
|
||||
import org.nd4j.autodiff.samediff.ops.*;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.ROC;
|
||||
import org.nd4j.graph.ExecutionMode;
|
||||
import org.nd4j.graph.FlatArray;
|
||||
import org.nd4j.graph.FlatConfiguration;
|
||||
import org.nd4j.graph.FlatGraph;
|
||||
import org.nd4j.graph.FlatNode;
|
||||
import org.nd4j.graph.FlatVariable;
|
||||
import org.nd4j.graph.IntPair;
|
||||
import org.nd4j.graph.OpType;
|
||||
import org.nd4j.graph.UpdaterState;
|
||||
import org.nd4j.graph.*;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
|
@ -112,8 +50,6 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
|||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.If;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.While;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
|
||||
|
@ -136,7 +72,6 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.learning.GradientUpdater;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
import org.nd4j.linalg.primitives.AtomicBoolean;
|
||||
import org.nd4j.linalg.primitives.AtomicDouble;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||
|
@ -152,6 +87,17 @@ import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
|
|||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||
|
||||
/**
|
||||
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
||||
* <p>
|
||||
|
@ -683,7 +629,7 @@ public class SameDiff extends SDBaseOps {
|
|||
for (val var : variables()) {
|
||||
SDVariable clone = var.clone(this);
|
||||
SDVariable newVar = sameDiff.var(clone);
|
||||
if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway
|
||||
if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway
|
||||
sameDiff.associateArrayWithVariable(var.getArr(), newVar);
|
||||
}
|
||||
|
||||
|
@ -795,9 +741,9 @@ public class SameDiff extends SDBaseOps {
     * @param function the function to get the inputs for
     * @return the input ids for a given function
     */
    public String[] getInputsForOp(DifferentialFunction function) {
    public String[] getInputsForOp(@NonNull DifferentialFunction function) {
        if (!ops.containsKey(function.getOwnName()))
            throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
            throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\"");
        List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
        return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
    }
|
||||
|
@ -1102,12 +1048,8 @@ public class SameDiff extends SDBaseOps {
                constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
                break;
            case ARRAY:
                // FIXME: remove this before release
                val session = sessions.get(Thread.currentThread().getId());
                val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null);
                session.getNodeOutputs().put(varId, arr);
                //throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY");
                break;
                throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
                        " this type of variable are calculated on demand during graph execution");
            case PLACEHOLDER:
                //Validate placeholder shapes:
                long[] phShape = variable.placeholderShape();
|
||||
|
@ -2152,11 +2094,32 @@ public class SameDiff extends SDBaseOps {
|
|||
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
|
||||
}
|
||||
|
||||
ArrayList<Listener> listenersWitHistory = new ArrayList<>(listeners);
|
||||
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
|
||||
for(Listener l : this.listeners){
|
||||
if(!listenersWitHistory.contains(l))
|
||||
listenersWitHistory.add(l);
|
||||
}
|
||||
listenersWitHistory.add(history);
|
||||
|
||||
for (int i = 0; i < numEpochs; i++) {
|
||||
|
||||
SameDiff gradInstance = getFunction("grad");
|
||||
if(gradInstance == null){
|
||||
createGradFunction();
|
||||
gradInstance = getFunction("grad");
|
||||
}
|
||||
TrainingSession ts = new TrainingSession(gradInstance);
|
||||
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
|
||||
|
||||
Set<String> paramsToTrain = new LinkedHashSet<>();
|
||||
for(Variable v : variables.values()){
|
||||
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
|
||||
//TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped
|
||||
paramsToTrain.add(v.getName());
|
||||
}
|
||||
}
|
||||
|
||||
Loss lastLoss = null;
|
||||
for (int i = 0; i < numEpochs; i++) {
|
||||
if (incrementEpochCount && hasListeners) {
|
||||
at.setEpoch(trainingConfig.getEpochCount());
|
||||
for (Listener l : activeListeners) {
|
||||
|
@ -2200,153 +2163,38 @@ public class SameDiff extends SDBaseOps {
|
|||
Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
|
||||
resolveVariablesWith(placeholders);
|
||||
|
||||
//Calculate gradients:
|
||||
execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners);
|
||||
|
||||
|
||||
//Apply updater:
|
||||
//Call TrainingSession to perform training
|
||||
if (!initializedTraining)
|
||||
initializeTraining();
|
||||
|
||||
Map<Class<?>, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners
|
||||
if (hasListeners) {
|
||||
regScore = new HashMap<>();
|
||||
}
|
||||
lastLoss = ts.trainingIteration(
|
||||
trainingConfig,
|
||||
placeholders,
|
||||
paramsToTrain,
|
||||
updaterMap,
|
||||
ds,
|
||||
getLossVariables(),
|
||||
listenersWitHistory,
|
||||
at);
|
||||
|
||||
int iteration = trainingConfig.getIterationCount();
|
||||
int e = trainingConfig.getEpochCount();
|
||||
for (Variable v : variables.values()) {
|
||||
//Only update trainable params - float type parameters (variable type vars)
|
||||
SDVariable sdv = v.getVariable();
|
||||
if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType())
|
||||
continue;
|
||||
|
||||
|
||||
INDArray param = sdv.getArr();
|
||||
SDVariable gradVar = sdv.getGradient();
|
||||
if (gradVar == null) {
|
||||
//Not all trainable parameters have gradients defined.
|
||||
//Consider graph: in1->loss1; in2->loss2, where we optimize only loss1.
|
||||
//No gradient will be present for in2, because in2 doesn't impact loss1 at all
|
||||
continue;
|
||||
}
|
||||
INDArray grad = gradVar.getArr();
|
||||
//Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
|
||||
// which should flow through to here
|
||||
|
||||
//Pre-apply regularization (L1, L2)
|
||||
List<Regularization> r = trainingConfig.getRegularization();
|
||||
int iterCount = trainingConfig.getIterationCount();
|
||||
int epochCount = trainingConfig.getEpochCount();
|
||||
double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
|
||||
if (r != null && r.size() > 0) {
|
||||
for (Regularization reg : r) {
|
||||
if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
|
||||
reg.apply(param, grad, lr, iterCount, epochCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Apply updater. Note that we need to reshape to [1,length] for updater
|
||||
INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
|
||||
Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv);
|
||||
GradientUpdater u = updaterMap.get(sdv.getVarName());
|
||||
try {
|
||||
u.applyUpdater(reshapedView, iteration, e);
|
||||
} catch (Throwable t) {
|
||||
throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName()
|
||||
+ "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t);
|
||||
}
|
||||
|
||||
//Post-apply regularization (weight decay)
|
||||
if (r != null && r.size() > 0) {
|
||||
for (Regularization reg : r) {
|
||||
if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
|
||||
reg.apply(param, grad, lr, iterCount, epochCount);
|
||||
if (hasListeners) {
|
||||
double score = reg.score(param, iterCount, epochCount);
|
||||
if (!regScore.containsKey(reg.getClass())) {
|
||||
regScore.put(reg.getClass(), new AtomicDouble());
|
||||
}
|
||||
regScore.get(reg.getClass()).addAndGet(score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasListeners) {
|
||||
for (Listener l : activeListeners) {
|
||||
if (l.isActive(at.operation()))
|
||||
l.preUpdate(this, at, v, reshapedView);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (trainingConfig.isMinimize()) {
|
||||
param.subi(grad);
|
||||
} else {
|
||||
param.addi(grad);
|
||||
}
|
||||
}
|
||||
|
||||
double[] d = new double[lossVariables.size() + regScore.size()];
|
||||
List<String> lossVars;
|
||||
if (regScore.size() > 0) {
|
||||
lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
|
||||
lossVars.addAll(lossVariables);
|
||||
int s = regScore.size();
|
||||
//Collect regularization losses
|
||||
for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
|
||||
lossVars.add(entry.getKey().getSimpleName());
|
||||
d[s] = entry.getValue().get();
|
||||
}
|
||||
} else {
|
||||
lossVars = lossVariables;
|
||||
}
|
||||
|
||||
//Collect the losses...
|
||||
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
|
||||
int count = 0;
|
||||
for (String s : lossVariables) {
|
||||
INDArray arr = gradFn.getArrForVarName(s);
|
||||
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
|
||||
d[count++] = l;
|
||||
}
|
||||
|
||||
Loss loss = new Loss(lossVars, d);
|
||||
|
||||
if (lossNames == null) {
|
||||
lossNames = lossVars;
|
||||
} else {
|
||||
Preconditions.checkState(lossNames.equals(lossVars),
|
||||
"Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
|
||||
}
|
||||
|
||||
if (lossSums == null) {
|
||||
lossSums = d;
|
||||
lossSums = lastLoss.getLosses().clone();
|
||||
} else {
|
||||
Preconditions.checkState(lossNames.equals(lossVars),
|
||||
"Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
|
||||
|
||||
for (int j = 0; j < lossSums.length; j++) {
|
||||
lossSums[j] += d[j];
|
||||
lossSums[j] += lastLoss.getLosses()[j];
|
||||
}
|
||||
}
|
||||
lossCount++;
|
||||
|
||||
if (hasListeners) {
|
||||
for (Listener l : activeListeners) {
|
||||
l.iterationDone(this, at, ds, loss);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
trainingConfig.incrementIterationCount();
|
||||
}
|
||||
|
||||
long epochTime = System.currentTimeMillis() - epochStartTime;
|
||||
|
||||
if (incrementEpochCount) {
|
||||
lossNames = lastLoss.getLossNames();
|
||||
|
||||
for (int j = 0; j < lossSums.length; j++)
|
||||
lossSums[j] /= lossCount;
|
||||
|
||||
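The removed block above (manual regularization, updater application, and loss collection) is now delegated to TrainingSession, so from the caller's side the entry point is unchanged. A hedged sketch of a minimal fit call under that new path; the placeholder names and the Builder options shown are illustrative of the existing TrainingConfig API, not additions made by this PR.

// Sketch: minimal training call; the updater/regularization logic now runs inside TrainingSession.
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;

public class FitSketch {
    public static History train(SameDiff sd, DataSetIterator iter, int numEpochs) {
        TrainingConfig config = new TrainingConfig.Builder()
                .updater(new Adam(1e-3))                 // updater state is initialized lazily, as above
                .dataSetFeatureMapping("in")             // placeholder fed from DataSet features
                .dataSetLabelMapping("label")            // placeholder fed from DataSet labels
                .build();
        sd.setTrainingConfig(config);
        return sd.fit(iter, numEpochs);                  // per-epoch losses/history collected as above
    }
}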
|
@ -2356,14 +2204,13 @@ public class SameDiff extends SDBaseOps {
|
|||
lossCurve = new LossCurve(lossSums, lossNames);
|
||||
}
|
||||
|
||||
|
||||
if (incrementEpochCount) {
|
||||
if (hasListeners) {
|
||||
|
||||
boolean doStop = false;
|
||||
Listener stopped = null;
|
||||
|
||||
for (Listener l : activeListeners) {
|
||||
|
||||
ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
|
||||
|
||||
if (res == ListenerResponse.STOP && (i < numEpochs - 1)) {
|
||||
|
@ -2431,7 +2278,6 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
trainingConfig.incrementEpochCount();
|
||||
}
|
||||
|
||||
if (i < numEpochs - 1) {
|
||||
iter.reset();
|
||||
}
|
||||
|
@ -2507,7 +2353,9 @@ public class SameDiff extends SDBaseOps {
|
|||
INDArray arr = v.getVariable().getArr();
|
||||
long stateSize = trainingConfig.getUpdater().stateSize(arr.length());
|
||||
INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize);
|
||||
updaterMap.put(v.getName(), trainingConfig.getUpdater().instantiate(view, true));
|
||||
GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false);
|
||||
gu.setStateViewArray(view, arr.shape(), arr.ordering(), true);
|
||||
updaterMap.put(v.getName(), gu);
|
||||
}
|
||||
|
||||
initializedTraining = true;
|
||||
|
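Both initialization paths above now create the updater without letting instantiate() initialize the view, then attach the state view explicitly with the parameter's shape and ordering. A hedged sketch of that two-step pattern in isolation, using Adam arbitrarily; the helper name is illustrative and the calls mirror those in the diff.

// Sketch of the two-step updater setup used above: create with an uninitialized state view,
// then bind the view with the parameter's shape/ordering so the state is laid out correctly.
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.Adam;

public class UpdaterInitSketch {
    public static GradientUpdater initFor(INDArray param) {
        Adam adam = new Adam(1e-3);
        long stateSize = adam.stateSize(param.length());
        INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(param.dataType(), 1, stateSize);

        GradientUpdater u = adam.instantiate(view, false);                 // don't initialize from the view
        u.setStateViewArray(view, param.shape(), param.ordering(), true);  // initialize explicitly
        return u;
    }
}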
@ -3862,7 +3710,8 @@ public class SameDiff extends SDBaseOps {
|
|||
long thisSize = trainingConfig.getUpdater().stateSize(arr.length());
|
||||
if (thisSize > 0) {
|
||||
INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize);
|
||||
GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true);
|
||||
GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false);
|
||||
u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call...
|
||||
updaterMap.put(v.getVarName(), u);
|
||||
} else {
|
||||
GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true);
|
||||
|
@ -3946,7 +3795,53 @@ public class SameDiff extends SDBaseOps {
|
|||
sessions.clear();
|
||||
|
||||
//Recalculate datatypes of outputs, and dynamically update them
|
||||
calculateOutputDataTypes(true);
|
||||
Set<String> allSeenOps = new HashSet<>();
|
||||
Queue<String> queueOps = new LinkedList<>();
|
||||
|
||||
for(String s : dataTypeMap.keySet()){
|
||||
Variable v = variables.get(s);
|
||||
v.getVariable().setDataType(dataTypeMap.get(s));
|
||||
List<String> inToOp = v.getInputsForOp();
|
||||
if(inToOp != null){
|
||||
for(String op : inToOp) {
|
||||
if (!allSeenOps.contains(op)) {
|
||||
allSeenOps.add(op);
|
||||
queueOps.add(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while(!queueOps.isEmpty()){
|
||||
String op = queueOps.remove();
|
||||
SameDiffOp o = ops.get(op);
|
||||
List<String> inVars = o.getInputsToOp();
|
||||
List<DataType> inDTypes = new ArrayList<>();
|
||||
if(inVars != null) {
|
||||
for (String s : inVars) {
|
||||
SDVariable v = variables.get(s).getVariable();
|
||||
inDTypes.add(v.dataType());
|
||||
}
|
||||
}
|
||||
List<DataType> outDtypes = o.getOp().calculateOutputDataTypes(inDTypes);
|
||||
List<String> outVars = o.getOutputsOfOp();
|
||||
for( int i=0; i<outVars.size(); i++ ){
|
||||
String varName = outVars.get(i);
|
||||
Variable var = variables.get(varName);
|
||||
SDVariable v = var.getVariable();
|
||||
v.setDataType(outDtypes.get(i));
|
||||
|
||||
//Also update queue
|
||||
if(var.getInputsForOp() != null){
|
||||
for(String opName : var.getInputsForOp()){
|
||||
if(!allSeenOps.contains(opName)){
|
||||
allSeenOps.add(opName);
|
||||
queueOps.add(opName);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4097,6 +3992,8 @@ public class SameDiff extends SDBaseOps {
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
variables.get(varName).getInputsForOp().remove(function.getOwnName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -4476,11 +4373,7 @@ public class SameDiff extends SDBaseOps {
|
|||
else if (function instanceof BaseOp) {
|
||||
SDVariable[] ret = new SDVariable[1];
|
||||
SDVariable checkGet = getVariable(baseName);
|
||||
char ordering = 'c';
|
||||
SDVariable[] args = function.args();
|
||||
if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye
|
||||
ordering = function.args()[0].getArr().ordering();
|
||||
}
|
||||
if (checkGet == null) {
|
||||
//Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme
|
||||
org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0);
|
||||
|
@ -4530,45 +4423,6 @@ public class SameDiff extends SDBaseOps {
|
|||
return sameDiffFunctionInstances.get(functionName);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
|
||||
*/
|
||||
@Deprecated
|
||||
public While whileStatement(SameDiffConditional sameDiffConditional,
|
||||
SameDiffFunctionDefinition conditionBody,
|
||||
SameDiffFunctionDefinition loopBody
|
||||
, SDVariable[] inputVars) {
|
||||
return While.builder()
|
||||
.inputVars(inputVars)
|
||||
.condition(conditionBody)
|
||||
.predicate(sameDiffConditional)
|
||||
.trueBody(loopBody)
|
||||
.parent(this)
|
||||
.blockName("while-" + UUID.randomUUID().toString())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
|
||||
*/
|
||||
@Deprecated
|
||||
public If ifStatement(SameDiffConditional conditional,
|
||||
SameDiffFunctionDefinition conditionBody,
|
||||
SameDiffFunctionDefinition trueBody,
|
||||
SameDiffFunctionDefinition falseBody
|
||||
, SDVariable[] inputVars) {
|
||||
return If.builder()
|
||||
.conditionBody(conditionBody)
|
||||
.falseBody(falseBody)
|
||||
.trueBody(trueBody)
|
||||
.predicate(conditional)
|
||||
.inputVars(inputVars)
|
||||
.parent(this)
|
||||
.blockName("if-" + UUID.randomUUID().toString())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new TensorArray.
|
||||
*/
|
||||
|
@ -4648,6 +4502,51 @@ public class SameDiff extends SDBaseOps {
        return execSingle(placeholders, outputs.get(0));
    }

    /**
     * See {@link #calculateGradients(Map, Collection)}
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull String... variables) {
        Preconditions.checkArgument(variables.length > 0, "No variables were specified");
        return calculateGradients(placeholderVals, Arrays.asList(variables));
    }

    /**
     * Calculate and return the gradients for the specified variables
     *
     * @param placeholderVals Placeholders. May be null
     * @param variables       Names of the variables that you want the gradient arrays for
     * @return Gradients as a map, keyed by the variable name
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
        Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
        if (getFunction(GRAD_FN_KEY) == null) {
            createGradFunction();
        }

        List<String> gradVarNames = new ArrayList<>(variables.size());
        for (String s : variables) {
            Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
            SDVariable v = getVariable(s).getGradient();
            if (v != null) {
                //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
                gradVarNames.add(v.getVarName());
            }
        }

        //Key is gradient variable name
        Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);

        Map<String, INDArray> out = new HashMap<>();
        for (String s : variables) {
            if (getVariable(s).getGradient() != null) {
                String gradVar = getVariable(s).getGradient().getVarName();
                out.put(s, grads.get(gradVar));
            }
        }

        return out;
    }
|
||||
|
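A short usage sketch for the new gradient API. The graph is illustrative and assumes the loss actually depends on the queried variable; note the result map is keyed by the original variable name, not the gradient variable name.

// Sketch: request gradient arrays directly, without calling execBackwards() and digging them out.
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CalculateGradientsSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 1));
        SDVariable loss = in.mmul(w).sum("loss");
        sd.setLossVariables("loss");

        Map<String, INDArray> placeholders =
                Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 8, 4));

        // Keyed by the *original* variable name ("w"), not the gradient variable name
        Map<String, INDArray> grads = sd.calculateGradients(placeholders, "w");
        INDArray dLdw = grads.get("w");
    }
}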
||||
/**
|
||||
* Create (if required) and then calculate the variable gradients (backward pass) for this graph.<br>
|
||||
* After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}<br>
|
||||
|
@ -4660,6 +4559,7 @@ public class SameDiff extends SDBaseOps {
|
|||
*
|
||||
* @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
|
||||
*/
|
||||
@Deprecated
|
||||
public void execBackwards(Map<String, INDArray> placeholders, Operation op) {
|
||||
execBackwards(placeholders, op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
|
||||
}
|
||||
|
@ -4669,10 +4569,12 @@ public class SameDiff extends SDBaseOps {
|
|||
* <p>
|
||||
* Uses {@link Operation#INFERENCE}.
|
||||
*/
|
||||
@Deprecated
|
||||
public void execBackwards(Map<String, INDArray> placeholders) {
|
||||
execBackwards(placeholders, Operation.INFERENCE);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
protected void execBackwards(Map<String, INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
|
||||
if (getFunction(GRAD_FN_KEY) == null) {
|
||||
createGradFunction();
|
||||
|
@ -4709,6 +4611,7 @@ public class SameDiff extends SDBaseOps {
|
|||
/**
|
||||
* See {@link #execBackwards(Map, List, Operation)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, Operation op, String... variableGradNamesList) {
|
||||
return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
|
||||
}
|
||||
|
@ -4718,6 +4621,7 @@ public class SameDiff extends SDBaseOps {
|
|||
* <p>
|
||||
* Uses {@link Operation#INFERENCE}.
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, String... variableGradNamesList) {
|
||||
return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList);
|
||||
}
|
||||
|
@ -4730,6 +4634,7 @@ public class SameDiff extends SDBaseOps {
|
|||
* @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
|
||||
* @param variableGradNamesList Names of the gradient variables to calculate
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation) {
|
||||
return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
|
||||
}
|
||||
|
@ -4739,10 +4644,12 @@ public class SameDiff extends SDBaseOps {
|
|||
* <p>
|
||||
* Uses {@link Operation#INFERENCE}.
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) {
|
||||
return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
protected Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation,
|
||||
MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
|
||||
if (getFunction(GRAD_FN_KEY) == null) {
|
||||
|
@ -5462,7 +5369,7 @@ public class SameDiff extends SDBaseOps {
|
|||
0,
|
||||
0,
|
||||
-1,
|
||||
0, 0, 0, 0, 0, 0);
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
return flatNode;
|
||||
}
|
||||
|
@ -5538,7 +5445,7 @@ public class SameDiff extends SDBaseOps {
|
|||
val idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();
|
||||
List<SDVariable> allVars = variables();
|
||||
for (SDVariable variable : allVars) {
|
||||
INDArray arr = variable.getArr();
|
||||
INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr();
|
||||
log.trace("Exporting variable: [{}]", variable.getVarName());
|
||||
|
||||
//If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
|
||||
|
@ -5582,7 +5489,26 @@ public class SameDiff extends SDBaseOps {
|
|||
shape = FlatVariable.createShapeVector(bufferBuilder, shp);
|
||||
}
|
||||
|
||||
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType);
|
||||
int controlDeps = 0;
|
||||
int controlDepsForOp = 0;
|
||||
int controlDepsForVar = 0;
|
||||
Variable v = variables.get(varName);
|
||||
|
||||
int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder);
|
||||
if(cds != null)
|
||||
controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds);
|
||||
|
||||
int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder);
|
||||
if(cdsForOp != null)
|
||||
controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp);
|
||||
|
||||
int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder);
|
||||
if(cdsForVar != null)
|
||||
controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar);
|
||||
|
||||
|
||||
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape,
|
||||
array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar);
|
||||
flatVariables.add(flatVariable);
|
||||
}
|
||||
|
||||
|
@ -5593,43 +5519,6 @@ public class SameDiff extends SDBaseOps {
|
|||
flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
|
||||
}
|
||||
|
||||
// we're dumping scopes now
|
||||
for (Map.Entry<String, SameDiff> scope : sameDiffFunctionInstances.entrySet()) {
|
||||
if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) {
|
||||
//Skip the gradient function for export
|
||||
continue;
|
||||
}
|
||||
|
||||
flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder));
|
||||
val currVarList = new ArrayList<SDVariable>(scope.getValue().variables());
|
||||
// converting all ops from node
|
||||
for (val node : scope.getValue().variables()) {
|
||||
INDArray arr = node.getArr();
|
||||
if (arr == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int name = bufferBuilder.createString(node.getVarName());
|
||||
int array = arr.toFlatArray(bufferBuilder);
|
||||
int id = IntPair.createIntPair(bufferBuilder, ++idx, 0);
|
||||
|
||||
val pair = parseVariable(node.getVarName());
|
||||
reverseMap.put(pair.getFirst(), idx);
|
||||
|
||||
log.trace("Adding [{}] as [{}]", pair.getFirst(), idx);
|
||||
|
||||
byte varType = (byte) node.getVariableType().ordinal();
|
||||
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType);
|
||||
flatVariables.add(flatVariable);
|
||||
}
|
||||
|
||||
//add functions
|
||||
for (SameDiffOp op : scope.getValue().ops.values()) {
|
||||
DifferentialFunction func = op.getOp();
|
||||
flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null));
|
||||
}
|
||||
}
|
||||
|
||||
int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets));
|
||||
int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
|
||||
int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));
|
||||
|
@ -5958,7 +5847,7 @@ public class SameDiff extends SDBaseOps {
|
|||
vars.add(fg.variables(i));
|
||||
}
|
||||
|
||||
FlatConfiguration conf = fg.configuration();
|
||||
// FlatConfiguration conf = fg.configuration();
|
||||
|
||||
/* Reconstruct the graph
|
||||
We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control
|
||||
|
@ -5995,6 +5884,35 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null);
|
||||
sd.variables.put(n, Variable.builder().name(n).variable(var).build());
|
||||
sd.variableNameToShape.put(n, shape);
|
||||
Variable v2 = sd.variables.get(n);
|
||||
|
||||
//Reconstruct control dependencies
|
||||
if(v.controlDepsLength() > 0){
|
||||
int num = v.controlDepsLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i=0; i<num; i++ ){
|
||||
l.add(v.controlDeps(i));
|
||||
}
|
||||
v2.setControlDeps(l);
|
||||
}
|
||||
if(v.controlDepForOpLength() > 0){
|
||||
int num = v.controlDepForOpLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i=0; i<num; i++ ){
|
||||
l.add(v.controlDepForOp(i));
|
||||
}
|
||||
v2.setControlDepsForOp(l);
|
||||
}
|
||||
|
||||
if(v.controlDepsForVarLength() > 0){
|
||||
int num = v.controlDepsForVarLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i=0; i<num; i++ ){
|
||||
l.add(v.controlDepsForVar(i));
|
||||
}
|
||||
v2.setControlDepsForVar(l);
|
||||
}
|
||||
|
||||
|
||||
|
||||
FlatArray fa = v.ndarray();
|
||||
|
@ -6063,7 +5981,37 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
inputNames[i] = varIn.getVarName();
|
||||
}
|
||||
sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames));
|
||||
SameDiffOp op = sd.ops.get(df.getOwnName());
|
||||
op.setInputsToOp(Arrays.asList(inputNames));
|
||||
|
||||
//Reconstruct control dependencies
|
||||
if (fn.controlDepsLength() > 0) {
|
||||
int l = fn.controlDepsLength();
|
||||
List<String> list = new ArrayList<>(l);
|
||||
for( int i=0; i<l; i++ ){
|
||||
list.add(fn.controlDeps(i));
|
||||
}
|
||||
op.setControlDeps(list);
|
||||
}
|
||||
|
||||
if (fn.varControlDepsLength() > 0) {
|
||||
int l = fn.varControlDepsLength();
|
||||
List<String> list = new ArrayList<>(l);
|
||||
for( int i=0; i<l; i++ ){
|
||||
list.add(fn.varControlDeps(i));
|
||||
}
|
||||
op.setVarControlDeps(list);
|
||||
}
|
||||
|
||||
if (fn.controlDepForLength() > 0) {
|
||||
int l = fn.controlDepForLength();
|
||||
List<String> list = new ArrayList<>(l);
|
||||
for( int i=0; i<l; i++ ){
|
||||
list.add(fn.controlDepFor(i));
|
||||
}
|
||||
op.setControlDepFor(list);
|
||||
}
|
||||
|
||||
|
||||
//Record that input variables are input to this op
|
||||
for (String inName : inputNames) {
|
||||
|
@ -6072,9 +6020,7 @@ public class SameDiff extends SDBaseOps {
|
|||
v.setInputsForOp(new ArrayList<String>());
|
||||
}
|
||||
if (!v.getInputsForOp().contains(df.getOwnName())) {
|
||||
v.getInputsForOp(
|
||||
|
||||
).add(df.getOwnName());
|
||||
v.getInputsForOp().add(df.getOwnName());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6414,32 +6360,6 @@ public class SameDiff extends SDBaseOps {
|
|||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate data types for the variables in the graph
|
||||
*/
|
||||
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() {
|
||||
return calculateOutputDataTypes(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate data types for the variables in the graph
|
||||
*/
|
||||
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
|
||||
List<String> allVars = new ArrayList<>(variables.keySet());
|
||||
DataTypesSession session = new DataTypesSession(this, dynamicUpdate);
|
||||
Map<String, org.nd4j.linalg.api.buffer.DataType> phValues = new HashMap<>();
|
||||
for (Variable v : variables.values()) {
|
||||
if (v.getVariable().isPlaceHolder()) {
|
||||
org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType();
|
||||
Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName());
|
||||
phValues.put(v.getName(), dt);
|
||||
}
|
||||
}
|
||||
Map<String, org.nd4j.linalg.api.buffer.DataType> out = session.output(allVars, phValues, null,
|
||||
Collections.<String>emptyList(), Collections.<Listener>emptyList(), At.defaultAt(Operation.INFERENCE));
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* For internal use only.
|
||||
* Creates a new distinct block name from baseName.
|
||||
|
@ -6470,14 +6390,14 @@ public class SameDiff extends SDBaseOps {
|
|||
* @return The imported graph
|
||||
*/
|
||||
public static SameDiff importFrozenTF(File graphFile) {
|
||||
return TFGraphMapper.getInstance().importGraph(graphFile);
|
||||
return TFGraphMapper.importGraph(graphFile);
|
||||
}
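//Illustrative usage (sketch only, not part of this change; the file name is hypothetical):
SameDiff imported = SameDiff.importFrozenTF(new File("frozen_model.pb"));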
|
||||
|
||||
/**
|
||||
* See {@link #importFrozenTF(File)}
|
||||
*/
|
||||
public static SameDiff importFrozenTF(GraphDef graphDef) {
|
||||
return TFGraphMapper.getInstance().importGraph(graphDef);
|
||||
return TFGraphMapper.importGraph(graphDef);
|
||||
}
|
||||
|
||||
|
||||
|
@ -6487,7 +6407,7 @@ public class SameDiff extends SDBaseOps {
|
|||
* Again, the input can be text or binary.
|
||||
*/
|
||||
public static SameDiff importFrozenTF(InputStream graph) {
|
||||
return TFGraphMapper.getInstance().importGraph(graph);
|
||||
return TFGraphMapper.importGraph(graph);
|
||||
}
|
||||
|
||||
|
||||
|
@ -6511,7 +6431,7 @@ public class SameDiff extends SDBaseOps {
|
|||
int start = 1;
|
||||
|
||||
// if we already have a name like "op_2", start from trying "op_3"
|
||||
if (base.contains("_")) {
|
||||
if (base.contains("_") && base.matches(".*_\\d+")) {
|
||||
// extract number used to generate base
|
||||
Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
|
||||
// extract argIndex used to generate base
|
||||
|
|
|
@ -0,0 +1,444 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.function.Predicate;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Object dependency tracker.
|
||||
* <br>
|
||||
* Dependencies are denoted by: X -> Y, which means "Y depends on X"<br>
|
||||
* In this implementation:<br>
|
||||
* - Dependencies may be satisfied, or not satisfied<br>
|
||||
* - The implementation tracks when the dependencies for an object Y are fully satisfied. This occurs when:<br>
|
||||
* 1. No dependencies X->Y exist<br>
|
||||
* 2. All dependencies of the form X->Y have been marked as satisfied, via markSatisfied(x)<br>
|
||||
* - When a dependency is satisfied, any dependent (Ys) are checked to see if all their dependencies are satisfied<br>
|
||||
* - If a dependent has all dependencies satisfied, it is added to the "new all satisfied" queue for processing,
|
||||
* which can be accessed via {@link #hasNewAllSatisfied()}, {@link #getNewAllSatisfied()} and {@link #getNewAllSatisfiedList()}<br>
|
||||
* <br>
|
||||
* Note: Two types of dependencies exist<br>
|
||||
* 1. Standard dependencies - i.e., "Y depends on X"<br>
|
||||
* 2. "Or" dependencies - i.e., "Y depends on (A or B)".<br>
|
||||
* For Or dependencies of the form "(A or B) -> Y", Y will be marked as "all dependencies satisfied" if either A or B is marked as satisfied.
|
||||
*
|
||||
* @param <T> For a dependency X -> Y, Y has type T
|
||||
* @param <D> For a dependency X -> Y, X has type D
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class AbstractDependencyTracker<T, D> {
|
||||
@Getter
|
||||
private final Map<T, Set<D>> dependencies; //Key: the dependent. Value: all things that the key depends on
|
||||
@Getter
|
||||
private final Map<T, Set<Pair<D, D>>> orDependencies; //Key: the dependent. Value: the set of OR dependencies
|
||||
private final Map<D, Set<T>> reverseDependencies = new HashMap<>(); //Key: the dependee. Value: The set of all dependents that depend on this value
|
||||
private final Map<D, Set<T>> reverseOrDependencies = new HashMap<>();
|
||||
private final Set<D> satisfiedDependencies = new HashSet<>(); //Dependencies that have been marked as satisfied. If not in set: assumed to not be satisfied
|
||||
|
||||
private final Set<T> allSatisfied; //Set of all dependent values (Ys) that have all dependencies satisfied
|
||||
private final Queue<T> allSatisfiedQueue = new LinkedList<>(); //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods
|
||||
|
||||
|
||||
protected AbstractDependencyTracker() {
|
||||
dependencies = (Map<T, Set<D>>) newTMap();
|
||||
orDependencies = (Map<T, Set<Pair<D, D>>>) newTMap();
|
||||
allSatisfied = newTSet();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A new map where the dependents (i.e., Y in "X -> Y") are the key
|
||||
*/
|
||||
protected abstract Map<T, ?> newTMap();
|
||||
|
||||
/**
|
||||
* @return A new set where the dependents (i.e., Y in "X -> Y") are the elements
|
||||
*/
|
||||
protected abstract Set<T> newTSet();
|
||||
|
||||
/**
|
||||
* @return A String representation of the dependent object
|
||||
*/
|
||||
protected abstract String toStringT(T t);
|
||||
|
||||
/**
|
||||
* @return A String representation of the dependee object
|
||||
*/
|
||||
protected abstract String toStringD(D d);
|
||||
|
||||
/**
|
||||
* Clear all internal state for the dependency tracker
|
||||
*/
|
||||
public void clear() {
|
||||
dependencies.clear();
|
||||
orDependencies.clear();
|
||||
reverseDependencies.clear();
|
||||
reverseOrDependencies.clear();
|
||||
satisfiedDependencies.clear();
|
||||
allSatisfied.clear();
|
||||
allSatisfiedQueue.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return True if no dependencies have been defined
|
||||
*/
|
||||
public boolean isEmpty() {
|
||||
return dependencies.isEmpty() && orDependencies.isEmpty() &&
|
||||
allSatisfiedQueue.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)}
|
||||
*/
|
||||
public boolean isSatisfied(@NonNull D x) {
|
||||
return satisfiedDependencies.contains(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark the specified value as satisfied.
|
||||
* For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
|
||||
* call, both of these dependencies are considered satisfied.
|
||||
*
|
||||
* @param x Value to mark
|
||||
* @param satisfied Whether to mark as satisfied (true) or unsatisfied (false)
|
||||
*/
|
||||
public void markSatisfied(@NonNull D x, boolean satisfied) {
|
||||
if (satisfied) {
|
||||
boolean alreadySatisfied = satisfiedDependencies.contains(x);
|
||||
|
||||
if (!alreadySatisfied) {
|
||||
satisfiedDependencies.add(x);
|
||||
|
||||
//Check if any Y's exist that have dependencies that are all satisfied, for X -> Y
|
||||
Set<T> s = reverseDependencies.get(x);
|
||||
Set<T> s2 = reverseOrDependencies.get(x);
|
||||
|
||||
Set<T> set;
|
||||
if (s != null && s2 != null) {
|
||||
set = newTSet();
|
||||
set.addAll(s);
|
||||
set.addAll(s2);
|
||||
} else if (s != null) {
|
||||
set = s;
|
||||
} else if (s2 != null) {
|
||||
set = s2;
|
||||
} else {
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("No values depend on: {}", toStringD(x));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (T t : set) {
|
||||
Set<D> required = dependencies.get(t);
|
||||
Set<Pair<D, D>> requiredOr = orDependencies.get(t);
|
||||
boolean allSatisfied = true;
|
||||
if (required != null) {
|
||||
for (D d : required) {
|
||||
if (!isSatisfied(d)) {
|
||||
allSatisfied = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (allSatisfied && requiredOr != null) {
|
||||
for (Pair<D, D> p : requiredOr) {
|
||||
if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) {
|
||||
allSatisfied = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (allSatisfied) {
|
||||
if (!this.allSatisfied.contains(t)) {
|
||||
this.allSatisfied.add(t);
|
||||
this.allSatisfiedQueue.add(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
satisfiedDependencies.remove(x);
|
||||
if (!allSatisfied.isEmpty()) {
|
||||
|
||||
Set<T> reverse = reverseDependencies.get(x);
|
||||
if (reverse != null) {
|
||||
for (T y : reverse) {
|
||||
if (allSatisfied.contains(y)) {
|
||||
allSatisfied.remove(y);
|
||||
allSatisfiedQueue.remove(y);
|
||||
}
|
||||
}
|
||||
}
|
||||
Set<T> orReverse = reverseOrDependencies.get(x);
|
||||
if (orReverse != null) {
|
||||
for (T y : orReverse) {
|
||||
if (allSatisfied.contains(y) && !isAllSatisfied(y)) {
|
||||
allSatisfied.remove(y);
|
||||
allSatisfiedQueue.remove(y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
|
||||
* or {@link #addOrDependency(Object, Object, Object)})
|
||||
*
|
||||
* @param y Dependent to check
|
||||
* @return True if Y depends on any values
|
||||
*/
|
||||
public boolean hasDependency(@NonNull T y) {
|
||||
Set<D> s1 = dependencies.get(y);
|
||||
if (s1 != null && !s1.isEmpty())
|
||||
return true;
|
||||
|
||||
Set<Pair<D, D>> s2 = orDependencies.get(y);
|
||||
return s2 != null && !s2.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all dependencies x, for x -> y, and (x1 or x2) -> y
|
||||
*
|
||||
* @param y Dependent to get dependencies for
|
||||
* @return List of dependencies
|
||||
*/
|
||||
public DependencyList<T, D> getDependencies(@NonNull T y) {
|
||||
Set<D> s1 = dependencies.get(y);
|
||||
Set<Pair<D, D>> s2 = orDependencies.get(y);
|
||||
|
||||
List<D> l1 = (s1 == null ? null : new ArrayList<>(s1));
|
||||
List<Pair<D, D>> l2 = (s2 == null ? null : new ArrayList<>(s2));
|
||||
|
||||
return new DependencyList<>(y, l1, l2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a dependency: y depends on x, as in x -> y
|
||||
*
|
||||
* @param y The dependent
|
||||
* @param x The dependee that is required for Y
|
||||
*/
|
||||
public void addDependency(@NonNull T y, @NonNull D x) {
|
||||
if (!dependencies.containsKey(y))
|
||||
dependencies.put(y, new HashSet<D>());
|
||||
|
||||
if (!reverseDependencies.containsKey(x))
|
||||
reverseDependencies.put(x, newTSet());
|
||||
|
||||
dependencies.get(y).add(x);
|
||||
reverseDependencies.get(x).add(y);
|
||||
|
||||
checkAndUpdateIfAllSatisfied(y);
|
||||
}
|
||||
|
||||
protected void checkAndUpdateIfAllSatisfied(@NonNull T y) {
|
||||
boolean allSat = isAllSatisfied(y);
|
||||
if (allSat) {
|
||||
//Case where "x is satisfied" happened before x->y added
|
||||
if (!allSatisfied.contains(y)) {
|
||||
allSatisfied.add(y);
|
||||
allSatisfiedQueue.add(y);
|
||||
}
|
||||
} else if (allSatisfied.contains(y)) {
|
||||
if (!allSatisfiedQueue.contains(y)) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies")
|
||||
.append(" were marked satisfied, but is now additional dependencies have been added.\n");
|
||||
DependencyList<T, D> dl = getDependencies(y);
|
||||
if (dl.getDependencies() != null) {
|
||||
sb.append("Dependencies:\n");
|
||||
for (D d : dl.getDependencies()) {
|
||||
sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n");
|
||||
}
|
||||
}
|
||||
if (dl.getOrDependencies() != null) {
|
||||
sb.append("Or dependencies:\n");
|
||||
for (Pair<D, D> p : dl.getOrDependencies()) {
|
||||
sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")");
|
||||
}
|
||||
}
|
||||
throw new IllegalStateException(sb.toString());
|
||||
}
|
||||
|
||||
//Not satisfied, but is in the queue -> needs to be removed
|
||||
allSatisfied.remove(y);
|
||||
allSatisfiedQueue.remove(y);
|
||||
}
|
||||
}
|
||||
|
||||
protected boolean isAllSatisfied(@NonNull T y) {
|
||||
Set<D> set1 = dependencies.get(y);
|
||||
|
||||
boolean allSatisfied = true;
|
||||
if (set1 != null) {
|
||||
for (D d : set1) {
|
||||
allSatisfied = isSatisfied(d);
|
||||
if (!allSatisfied)
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (allSatisfied) {
|
||||
Set<Pair<D, D>> set2 = orDependencies.get(y);
|
||||
if (set2 != null) {
|
||||
for (Pair<D, D> p : set2) {
|
||||
allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
|
||||
if (!allSatisfied)
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return allSatisfied;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Remove a dependency (x -> y)
|
||||
*
|
||||
* @param y The dependent that currently requires X
|
||||
* @param x The dependee that is no longer required for Y
|
||||
*/
|
||||
public void removeDependency(@NonNull T y, @NonNull D x) {
|
||||
if (!dependencies.containsKey(y) && !orDependencies.containsKey(y))
|
||||
return;
|
||||
|
||||
Set<D> s = dependencies.get(y);
|
||||
if (s != null) {
|
||||
s.remove(x);
|
||||
if (s.isEmpty())
|
||||
dependencies.remove(y);
|
||||
}
|
||||
|
||||
Set<T> s2 = reverseDependencies.get(x);
|
||||
if (s2 != null) {
|
||||
s2.remove(y);
|
||||
if (s2.isEmpty())
|
||||
reverseDependencies.remove(x);
|
||||
}
|
||||
|
||||
|
||||
Set<Pair<D, D>> s3 = orDependencies.get(y);
|
||||
if (s3 != null) {
|
||||
boolean removedReverse = false;
|
||||
Iterator<Pair<D, D>> iter = s3.iterator();
|
||||
while (iter.hasNext()) {
|
||||
Pair<D, D> p = iter.next();
|
||||
if (x.equals(p.getFirst()) || x.equals(p.getSecond())) {
|
||||
iter.remove();
|
||||
|
||||
if (!removedReverse) {
|
||||
Set<T> set1 = reverseOrDependencies.get(p.getFirst());
|
||||
Set<T> set2 = reverseOrDependencies.get(p.getSecond());
|
||||
|
||||
set1.remove(y);
|
||||
set2.remove(y);
|
||||
|
||||
if (set1.isEmpty())
|
||||
reverseOrDependencies.remove(p.getFirst());
|
||||
if (set2.isEmpty())
|
||||
reverseOrDependencies.remove(p.getSecond());
|
||||
|
||||
removedReverse = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (s3 != null && s3.isEmpty())
|
||||
orDependencies.remove(y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br>
|
||||
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
|
||||
* dependency is considered satisfied
|
||||
*
|
||||
* @param y Dependent
|
||||
* @param x1 Dependee 1
|
||||
* @param x2 Dependee 2
|
||||
*/
|
||||
public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) {
|
||||
if (!orDependencies.containsKey(y))
|
||||
orDependencies.put(y, new HashSet<Pair<D, D>>());
|
||||
|
||||
if (!reverseOrDependencies.containsKey(x1))
|
||||
reverseOrDependencies.put(x1, newTSet());
|
||||
if (!reverseOrDependencies.containsKey(x2))
|
||||
reverseOrDependencies.put(x2, newTSet());
|
||||
|
||||
orDependencies.get(y).add(new Pair<>(x1, x2));
|
||||
reverseOrDependencies.get(x1).add(y);
|
||||
reverseOrDependencies.get(x2).add(y);
|
||||
|
||||
checkAndUpdateIfAllSatisfied(y);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y)
|
||||
*/
|
||||
public boolean hasNewAllSatisfied() {
|
||||
return !allSatisfiedQueue.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
|
||||
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
|
||||
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value;
|
||||
* the value is considered "processed" at this point.
|
||||
*
|
||||
* @return The next new "all satisfied dependent"
|
||||
*/
|
||||
public T getNewAllSatisfied() {
|
||||
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||
return allSatisfiedQueue.remove();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return As per {@link #getNewAllSatisfied()} but returns all values
|
||||
*/
|
||||
public List<T> getNewAllSatisfiedList() {
|
||||
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||
List<T> ret = new ArrayList<>(allSatisfiedQueue);
|
||||
allSatisfiedQueue.clear();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* As per {@link #getNewAllSatisfied()} but instead of returning the first dependent, it returns the first that matches
|
||||
* the provided predicate. If no value matches the predicate, null is returned
|
||||
*
|
||||
* @param predicate Predicate for checking
|
||||
* @return The first value matching the predicate, or null if no values match the predicate
|
||||
*/
|
||||
public T getFirstNewAllSatisfiedMatching(@NonNull Predicate<T> predicate) {
|
||||
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
|
||||
|
||||
T t = allSatisfiedQueue.peek();
|
||||
if (predicate.test(t)) {
|
||||
t = allSatisfiedQueue.remove();
|
||||
allSatisfied.remove(t);
|
||||
return t;
|
||||
}
|
||||
|
||||
if (allSatisfiedQueue.size() > 1) {
|
||||
Iterator<T> iter = allSatisfiedQueue.iterator();
|
||||
while (iter.hasNext()) {
|
||||
t = iter.next();
|
||||
if (predicate.test(t)) {
|
||||
iter.remove();
|
||||
allSatisfied.remove(t);
|
||||
return t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null; //None match predicate
|
||||
}
|
||||
}
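//Usage sketch for the semantics described in the class javadoc above (not part of this change). Uses the
//HashMap-backed DependencyTracker subclass added later in this change; the string values are placeholders:
DependencyTracker<String, String> tracker = new DependencyTracker<>();
tracker.addDependency("y", "x");                   //x -> y
tracker.addOrDependency("y", "a", "b");            //(a or b) -> y
tracker.markSatisfied("x", true);
tracker.markSatisfied("a", true);                  //Satisfies the (a or b) -> y dependency
while (tracker.hasNewAllSatisfied()) {
    String ready = tracker.getNewAllSatisfied();   //Returns "y": all of its dependencies are now satisfied
}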
|
File diff suppressed because it is too large
|
@ -1,107 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* Infer datatypes for all variables.
|
||||
* Optionally update the datatypes of variables as we go
|
||||
*/
|
||||
public class DataTypesSession extends AbstractSession<DataType, DataTypesSession.DataTypeCalc> {
|
||||
|
||||
protected boolean dynamicUpdate;
|
||||
|
||||
/**
|
||||
* @param sameDiff SameDiff instance
|
||||
* @param dynamicUpdate If true: Dynamically update the datatypes as we go
|
||||
*/
|
||||
public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) {
|
||||
super(sameDiff);
|
||||
this.dynamicUpdate = dynamicUpdate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getConstantOrVariable(String variableName) {
|
||||
//Variables and constants should always have datatype available
|
||||
DataType dt = sameDiff.getVariable(variableName).dataType();
|
||||
Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName);
|
||||
return dt;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
|
||||
DifferentialFunction df = sameDiff.getOpById(opName);
|
||||
List<DataType> inputDataTypes = new ArrayList<>();
|
||||
for(SDVariable v : df.args()){
|
||||
DataType dt = v.dataType();
|
||||
if(dt != null){
|
||||
inputDataTypes.add(dt);
|
||||
} else {
|
||||
String s = v.getVarName();
|
||||
for(VarId vid : inputs){
|
||||
if(vid.getVariable().equals(s)){
|
||||
DataType dt2 = nodeOutputs.get(vid);
|
||||
Preconditions.checkNotNull(dt2, "No datatype for %s", vid);
|
||||
inputDataTypes.add(dt2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new DataTypeCalc(df, inputDataTypes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
|
||||
List<DataType> outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes());
|
||||
|
||||
if(dynamicUpdate) {
|
||||
SDVariable[] fnOutputs = op.getFn().outputVariables();
|
||||
for( int i=0; i<fnOutputs.length; i++ ){
|
||||
SDVariable v = fnOutputs[i];
|
||||
DataType d = outTypes.get(i);
|
||||
if(v.dataType() != d){
|
||||
v.setDataType(d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return outTypes.toArray(new DataType[outTypes.size()]);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
protected static class DataTypeCalc {
|
||||
protected final DifferentialFunction fn;
|
||||
protected final List<DataType> inputTypes;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A list of dependencies, used in {@link AbstractDependencyTracker}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
public class DependencyList<T, D> {
|
||||
private T dependencyFor;
|
||||
private List<D> dependencies;
|
||||
private List<Pair<D, D>> orDependencies;
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Dependency tracker. See {@link AbstractDependencyTracker} for details
|
||||
*
|
||||
* @param <T> For a dependency X -> Y, Y has type T
|
||||
* @param <D> For a dependency X -> Y, X has type D
|
||||
*/
|
||||
@Slf4j
|
||||
public class DependencyTracker<T, D> extends AbstractDependencyTracker<T,D> {
|
||||
|
||||
@Override
|
||||
protected Map<T, ?> newTMap() {
|
||||
return new HashMap<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<T> newTSet() {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String toStringT(T t) {
|
||||
return t.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String toStringD(D d) {
|
||||
return d.toString();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Object dependency tracker, using object identity (not object equality) for the Ys (of type T)<br>
|
||||
* See {@link AbstractDependencyTracker} for more details
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class IdentityDependencyTracker<T, D> extends AbstractDependencyTracker<T,D> {
|
||||
|
||||
@Override
|
||||
protected Map<T, ?> newTMap() {
|
||||
return new IdentityHashMap<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<T> newTSet() {
|
||||
return Collections.newSetFromMap(new IdentityHashMap<T, Boolean>());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String toStringT(T t) {
|
||||
if(t instanceof INDArray){
|
||||
INDArray i = (INDArray)t;
|
||||
return System.identityHashCode(t) + " - id=" + i.getId() + ", " + i.shapeInfoToString();
|
||||
} else {
|
||||
return System.identityHashCode(t) + " - " + t.toString();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String toStringD(D d) {
|
||||
return d.toString();
|
||||
}
|
||||
}
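//Sketch of why identity (not equality) matters here (not part of this change; the dependee strings are
//placeholders). Two arrays with equal contents are tracked as separate entries:
IdentityDependencyTracker<INDArray, String> tracker = new IdentityDependencyTracker<>();
INDArray a = Nd4j.zeros(2, 2);
INDArray b = Nd4j.zeros(2, 2);                     //a.equals(b) is true, but they are distinct objects
tracker.addDependency(a, "opA");
tracker.addDependency(b, "opB");
tracker.markSatisfied("opA", true);
INDArray ready = tracker.getNewAllSatisfied();     //Returns a only; b still depends on "opB"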
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
|
@ -24,15 +24,17 @@ import org.nd4j.autodiff.listeners.Listener;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.*;
|
||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.If;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.While;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Stack;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
|
||||
|
@ -48,29 +50,85 @@ import org.nd4j.linalg.util.ArrayUtil;
|
|||
import java.util.*;
|
||||
|
||||
/**
|
||||
* InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes.
|
||||
* Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs.
|
||||
* InferenceSession: Performs inference (forward pass) on a SameDiff instance to get the outputs of the requested nodes.<br>
|
||||
* Dynamically (in AbstractSession) calculates the required subgraph to execute to get the required outputs.<br>
|
||||
* Note that while AbstractSession handles the graph structure component, InferenceSession handles only op execution
|
||||
* and memory management<br>
|
||||
* <br>
|
||||
* For INDArray memory management - i.e., tracking and releasing memory manually, as soon as possible, to
|
||||
* minimize memory use - this is implemented using a {@link SessionMemMgr} instance (for allocations/deallocations) and
|
||||
* also {@link IdentityDependencyTracker} to track where arrays are actually used. The IdentityDependencyTracker tells
|
||||
* us when the array is no longer needed (i.e., has been "fully consumed" by all ops depending on it) accounting for the
|
||||
* fact that some operations, such as identity, enter, exit, etc., are "zero copy" for performance reasons.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class InferenceSession extends AbstractSession<INDArray,DifferentialFunction> {
|
||||
public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
|
||||
private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\n" +
|
||||
"Alternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
|
||||
|
||||
protected static final String KERAS_TRAIN_TEST = "keras_learning_phase";
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private SessionMemMgr mmgr; //Used for allocating and deallocating memory
|
||||
/**
|
||||
* Array use tracker: What needs to happen before the array can be closed/released?
|
||||
* As the name suggests, the INDArrays are tracked using object identity, not equality
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
private IdentityDependencyTracker<INDArray, Dep> arrayUseTracker = new IdentityDependencyTracker<>();
|
||||
|
||||
|
||||
public InferenceSession(@NonNull SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
|
||||
mmgr = new ArrayCloseMemoryMgr(); //TODO replace this with new (planned) array reuse memory manager
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Map<String,INDArray> preprocessPlaceholders(Map<String,INDArray> placeholders){
|
||||
//Handle casting of the input array automatically.
|
||||
//The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double
|
||||
// array for a float placeholder
|
||||
protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> placeholders, At at) {
|
||||
arrayUseTracker.clear();
|
||||
|
||||
//We'll also use this method as a "pre execution" hook-in, to mark variables as something we should never deallocate
|
||||
//This occurs by never marking these "ConstantDep" and "VariableDep" instances as satisfied, so there's always
|
||||
// an unsatisfied dependency for them in the array use tracker
|
||||
//TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration...
|
||||
for (SDVariable v : sameDiff.variables()) {
|
||||
if (v.getVariableType() == VariableType.CONSTANT) {
|
||||
arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.getVarName()));
|
||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||
arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.getVarName()));
|
||||
}
|
||||
}
|
||||
|
||||
//Workaround for some TF/Keras based models that require explicit train/test as a placeholder
|
||||
boolean kerasWorkaround = false;
|
||||
List<String> phs = sameDiff.inputs();
|
||||
if (phs != null && !phs.isEmpty()) {
|
||||
for (String s : phs) {
|
||||
if (s.endsWith(KERAS_TRAIN_TEST) && !placeholders.containsKey(s)) {
|
||||
// The behaviour of some Keras layers (like GRU) differs depending on whether the model is training.
|
||||
// We provide this value directly, unless the user has provided this manually
|
||||
INDArray scalar = mmgr.allocate(false, DataType.BOOL).assign(at.operation().isTrainingPhase());
|
||||
placeholders = new HashMap<>(placeholders); //Array might be singleton, or otherwise unmodifiable
|
||||
placeholders.put(s, scalar);
|
||||
kerasWorkaround = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (placeholders == null || placeholders.isEmpty()) {
|
||||
return placeholders;
|
||||
}
|
||||
|
||||
//Handle casting of the input array automatically.
|
||||
//The idea here is to avoid unexpected errors if the user (for example) tries to perform inference with a double
|
||||
// array for a float placeholder
|
||||
//TODO eventually we might have ops that support multiple input types, and hence won't need this casting
|
||||
Map<String, INDArray> out = new HashMap<>();
|
||||
for (Map.Entry<String, INDArray> e : placeholders.entrySet()) {
|
||||
Preconditions.checkState(sameDiff.hasVariable(e.getKey()), "Invalid placeholder passed for execution: " +
|
||||
|
@ -96,51 +154,195 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
|
||||
|
||||
//Second: cast the input to the required type
|
||||
//TODO For the casting case, we SHOULD actually deallocate this when we're done with it, which is usually sooner than "exec done"
|
||||
DataType dt = sameDiff.getVariable(e.getKey()).dataType();
|
||||
if(arr.dataType() != dt){
|
||||
arr = arr.castTo(dt);
|
||||
if (kerasWorkaround && e.getKey().endsWith(KERAS_TRAIN_TEST)) {
|
||||
arrayUseTracker.addDependency(arr, new ExecDoneDep());
|
||||
} else if (arr.dataType() == dt) {
|
||||
//Mark as a placeholder array in the array use tracker, so we never deallocate this array...
|
||||
arrayUseTracker.addDependency(e.getValue(), new PlaceholderDep(e.getKey()));
|
||||
} else {
|
||||
INDArray cast = mmgr.allocate(false, dt, arr.shape());
|
||||
cast.assign(arr);
|
||||
arr = cast;
|
||||
//This array CAN be deallocated once consumed, because of the cast
|
||||
//TODO we can likely close this sooner
|
||||
arrayUseTracker.addDependency(arr, new ExecDoneDep());
|
||||
}
|
||||
out.put(e.getKey(), arr);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] getOutputs(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
|
||||
protected Map<String, INDArray> postProcessOutput(Map<String, INDArray> output) {
|
||||
|
||||
//For any queued (not yet processed) ops - mark them as satisfied, so we can deallocate any arrays
|
||||
// that are waiting on them
|
||||
if (dt.hasNewAllSatisfied()) {
|
||||
List<ExecStep> execSteps = dt.getNewAllSatisfiedList();
|
||||
for (ExecStep es : execSteps) {
|
||||
if (es.getType() == ExecType.OP) {
|
||||
OpDep od = new OpDep(es.getName(), es.getFrameIter().getFrame(), es.getFrameIter().getIteration(), es.getFrameIter().getParentFrame());
|
||||
arrayUseTracker.markSatisfied(od, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Also mark "end of execution" for array dependency tracker. Mainly used for TensorArray arrays at present.
|
||||
//TODO Optimize for reduced memory for some TensorArray operations - i.e., close/deallocate earlier
|
||||
arrayUseTracker.markSatisfied(new ExecDoneDep(), true);
|
||||
if (arrayUseTracker.hasNewAllSatisfied()) {
|
||||
List<INDArray> l = arrayUseTracker.getNewAllSatisfiedList();
|
||||
for (INDArray arr : l) {
|
||||
mmgr.release(arr);
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
|
||||
if (listeners != null && listeners.size() > 0) {
|
||||
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
|
||||
SameDiffOp sdOp = sameDiff.getOps().get(op.getOp().getOwnName());
|
||||
for (Listener l : listeners) {
|
||||
if (l.isActive(at.operation()))
|
||||
l.preOpExecution(sameDiff, at, sdOp);
|
||||
}
|
||||
}
|
||||
|
||||
INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
|
||||
INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
|
||||
op.getOp().clearArrays();
|
||||
|
||||
if (log.isTraceEnabled()) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append(op.getName()).append(" - ").append(outputFrameIter).append(" outputs: ");
|
||||
List<String> opOutNames = op.getOutputsOfOp();
|
||||
for (int i = 0; i < out.length; i++) {
|
||||
if (i > 0)
|
||||
sb.append(", ");
|
||||
sb.append("(").append(i).append(" - ").append(opOutNames.get(i)).append(" = ").append(
|
||||
out[i] == null ? null : out[i].getId()).append(")");
|
||||
}
|
||||
log.trace(sb.toString());
|
||||
}
|
||||
|
||||
//Call listeners, before we (maybe) deallocate input arrays
|
||||
if (listeners != null && listeners.size() > 0) {
|
||||
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
|
||||
|
||||
Map<String, INDArray> namedOutsBuilder = new HashMap<>();
|
||||
|
||||
for(int i = 0 ; i < out.length ; i++)
|
||||
namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]);
|
||||
|
||||
Map<String, INDArray> namedOuts = Collections.unmodifiableMap(namedOutsBuilder);
|
||||
Map<String, INDArray> namedOuts = null;
|
||||
|
||||
for (Listener l : listeners) {
|
||||
if (l.isActive(at.operation())) {
|
||||
l.opExecution(sameDiff, at, batch, sdOp, out);
|
||||
//Lazily create map, only if required
|
||||
if (namedOuts == null) {
|
||||
Map<String, INDArray> namedOutsBuilder = new HashMap<>();
|
||||
|
||||
for (int i = 0; i < out.length; i++)
|
||||
namedOutsBuilder.put(op.outputsOfOp.get(i), out[i]);
|
||||
namedOuts = Collections.unmodifiableMap(namedOutsBuilder);
|
||||
}
|
||||
|
||||
|
||||
l.opExecution(sameDiff, at, batch, op, out);
|
||||
|
||||
for (String varName : namedOuts.keySet()) {
|
||||
l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName));
|
||||
l.activationAvailable(sameDiff, at, batch, op, varName, namedOuts.get(varName));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//Record array uses for memory management/deallocation
|
||||
SameDiffOp o = sameDiff.getOps().get(op.getName());
|
||||
List<String> outVarNames = o.getOutputsOfOp();
|
||||
for (int i = 0; i < out.length; i++) {
|
||||
if (out[i] == null && o.getOp() instanceof Switch)
|
||||
continue; //Switch case: we only ever get one of 2 outputs, other is null (branch not executed)
|
||||
|
||||
String name = outVarNames.get(i);
|
||||
Variable v = sameDiff.getVariables().get(name);
|
||||
List<String> inputsForOps = v.getInputsForOp();
|
||||
if (inputsForOps != null) {
|
||||
for (String opName : inputsForOps) {
|
||||
//Only add dependencies if we actually need the op this feeds into, otherwise the dependency
|
||||
// will never be marked as satisfied
|
||||
if (!subgraphOps.contains(opName))
|
||||
continue;
|
||||
|
||||
SameDiffOp forOp = sameDiff.getOps().get(opName);
|
||||
|
||||
//TODO do switch or merge need special handling also?
|
||||
if (forOp.getOp() instanceof Enter) {
|
||||
Enter e = (Enter) forOp.getOp();
|
||||
if (e.isConstant()) {
|
||||
/*
|
||||
Constant enter case: Need to keep this array around for the entire duration of the frame, including
|
||||
any nested frames, and all iterations.
|
||||
Unfortunately, we don't know exactly when we're done with a frame for good
|
||||
This isn't a great solution, but other possibilities (frame close, trying to detect all exit ops,
|
||||
detecting return to parent frame, etc.) all fail in certain circumstances, such as due to control dependencies
|
||||
on variables.
|
||||
*/
|
||||
Dep d = new ExecDoneDep();
|
||||
arrayUseTracker.addDependency(out[i], d);
|
||||
} else {
|
||||
Dep d = new OpDep(opName, e.getFrameName(), 0, outputFrameIter);
|
||||
arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed
|
||||
}
|
||||
} else if (forOp.getOp() instanceof NextIteration) {
|
||||
//The array is needed by the NEXT iteration op, not the current one
|
||||
Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration() + 1, outputFrameIter.getParentFrame());
|
||||
arrayUseTracker.addDependency(out[i], d);
|
||||
} else if (forOp.getOp() instanceof Exit) {
|
||||
//The array is needed at the EXIT frame (i.e., parent frame), not the inner/just executed one
|
||||
FrameIter fi = outputFrameIter.getParentFrame();
|
||||
Dep d = new OpDep(opName, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
|
||||
arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed
|
||||
} else {
|
||||
//All other ops...
|
||||
Dep d = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
|
||||
arrayUseTracker.addDependency(out[i], d); //Op defined by "d" needs to be executed before specified array can be closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (OUTER_FRAME.equals(outputFrameIter.getFrame()) && allReqVariables.contains(name)) {
|
||||
//This variable is an output, record that in the array use tracker, so we don't deallocate it
|
||||
arrayUseTracker.addDependency(out[i], new ReqOutputDep(name));
|
||||
} else if ((inputsForOps == null || inputsForOps.isEmpty()) && !arrayUseTracker.hasDependency(out[i])) {
|
||||
//This particular array is not actually needed anywhere, so we can deallocate it immediately
|
||||
//Possibly only a control dependency, or only one of the outputs of a multi-output op is used
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("Found array id {} (output of {}) not required anywhere, deallocating", out[i].getId(), o.getName());
|
||||
}
|
||||
mmgr.release(out[i]);
|
||||
}
|
||||
}
|
||||
|
||||
//Mark current op dependency as satisfied...
|
||||
Dep d = new OpDep(op.getName(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
|
||||
arrayUseTracker.markSatisfied(d, true);
|
||||
|
||||
|
||||
//Close any no longer required arrays
|
||||
if (arrayUseTracker.hasNewAllSatisfied()) {
|
||||
List<INDArray> canClose = arrayUseTracker.getNewAllSatisfiedList();
|
||||
for (INDArray arr : canClose) {
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("Closing array... id={}, {}", arr.getId(), arr.shapeInfoToString());
|
||||
}
|
||||
mmgr.release(arr);
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
public INDArray[] getOutputsHelper(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
public INDArray[] doExec(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs) {
|
||||
|
||||
int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size())
|
||||
|
@ -151,17 +353,18 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
if (op instanceof Identity) {
|
||||
Identity i = (Identity) op;
|
||||
String[] argNames = i.argNames();
|
||||
Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", argNames);
|
||||
VarId vid = newVarId(argNames[0], outputFrameIter);
|
||||
return new INDArray[]{nodeOutputs.get(vid)};
|
||||
Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object) argNames);
|
||||
VarId vid = outputFrameIter.toVarId(argNames[0]);
|
||||
|
||||
INDArray orig = nodeOutputs.get(vid);
|
||||
return new INDArray[]{orig};
|
||||
} else if (op instanceof Switch) {
|
||||
Switch s = (Switch) op;
|
||||
String[] argNames = s.argNames(); //Order: input, boolean array
|
||||
VarId vidPredicate = newVarId(argNames[1], outputFrameIter);
|
||||
VarId vidPredicate = outputFrameIter.toVarId(argNames[1]);
|
||||
INDArray predicate = this.nodeOutputs.get(vidPredicate);
|
||||
Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate);
|
||||
VarId vid = newVarId(argNames[0], outputFrameIter);
|
||||
VarId vid = outputFrameIter.toVarId(argNames[0]);
|
||||
if (predicate.getDouble(0) == 0.0) {
|
||||
return new INDArray[]{this.nodeOutputs.get(vid), null};
|
||||
} else {
|
||||
|
@ -171,7 +374,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
//Enter op: forwards input to specified execution frame
|
||||
Enter e = (Enter) op;
|
||||
String[] input = e.argNames();
|
||||
Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", input);
|
||||
Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", (Object) input);
|
||||
Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", e.getOwnName(), opInputs, constAndPhInputs);
|
||||
|
||||
VarId inputVarId;
|
||||
|
@ -211,23 +414,23 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
" be 1 larger than the input iteration. Input: %s, output %s", in, outputFrameIter);
|
||||
|
||||
INDArray inArr = this.nodeOutputs.get(in);
|
||||
if (inArr == null) {
|
||||
Preconditions.throwStateEx("Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)",
|
||||
op.getOwnName(), sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration());
|
||||
}
|
||||
return new INDArray[]{inArr};
|
||||
} else if(op instanceof If) {
|
||||
If i = (If) op;
|
||||
String[] argNames = i.argNames(); //Order should be: [boolean], true, false
|
||||
|
||||
|
||||
throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
|
||||
} else if (op instanceof Merge) {
|
||||
//Merge avairable for forward pass when any of its inputs are available. When multiple are available, behaviour
|
||||
//Merge available for forward pass when any of its inputs are available. When multiple are available, behaviour
|
||||
// is undefined
|
||||
Merge m = (Merge) op;
|
||||
String[] in = sameDiff.getInputsForOp(op);
|
||||
for (String s : in) {
|
||||
VarId vid = newVarId(s, outputFrameIter);
|
||||
VarId vid = outputFrameIter.toVarId(s);
|
||||
if (nodeOutputs.containsKey(vid)) {
|
||||
log.trace("Returning input \"{}\" for merge node \"{}\"", m.getOwnName(), s);
|
||||
return new INDArray[]{nodeOutputs.get(vid)};
|
||||
INDArray arr = nodeOutputs.get(vid);
|
||||
Preconditions.checkState(arr != null, "Could not find output array for %s", vid);
|
||||
return new INDArray[]{arr};
|
||||
}
|
||||
}
|
||||
throw new IllegalStateException("Merge node " + m.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(in) +
|
||||
|
@ -236,25 +439,59 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
//LoopCond just forwards scalar boolean to output
|
||||
LoopCond lc = (LoopCond) op;
|
||||
String[] argNames = lc.argNames();
|
||||
Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", argNames);
|
||||
VarId vid = newVarId(argNames[0], outputFrameIter);
|
||||
Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object) argNames);
|
||||
VarId vid = outputFrameIter.toVarId(argNames[0]);
|
||||
INDArray arr = nodeOutputs.get(vid);
|
||||
Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null");
|
||||
Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape", arr);
|
||||
return new INDArray[]{arr};
|
||||
} else if (op instanceof BaseTensorOp) {
|
||||
//TensorOps - special cases...
|
||||
return getOutputsHelperTensorArrayOps(op, outputFrameIter, opInputs, allIterInputs);
|
||||
} else if (op instanceof GradientBackwardsMarker) {
|
||||
INDArray out = mmgr.allocate(false, DataType.FLOAT).assign(1.0f);
|
||||
return new INDArray[]{out};
|
||||
} else if (op instanceof ExternalErrorsFunction) {
|
||||
ExternalErrorsFunction fn = (ExternalErrorsFunction) op;
|
||||
String n = fn.getGradPlaceholderName();
|
||||
INDArray arr = nodeOutputs.get(new VarId(n, OUTER_FRAME, 0, null));
|
||||
Preconditions.checkState(arr != null, "Could not find external errors placeholder array: %s", n);
|
||||
INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
|
||||
out.assign(arr);
|
||||
return new INDArray[]{out};
|
||||
} else if (op instanceof CustomOp) {
|
||||
CustomOp c = (CustomOp) op;
|
||||
Nd4j.exec(c);
|
||||
return c.outputArguments();
|
||||
} else if (op instanceof Op) {
|
||||
Op o = (Op) op;
|
||||
Nd4j.exec(o);
|
||||
return new INDArray[]{o.z()};
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Forward pass for TensorArray ops
|
||||
*/
|
||||
public INDArray[] getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs) {
|
||||
/*
|
||||
TODO: TensorArray memory management note: For now, we'll close any INDArrays stored in the TensorArray at the end of
|
||||
graph execution. This uses more memory than necessary for an earlier close strategy, but simplifies memory management.
|
||||
This should be revisited and optimized later
|
||||
*/
|
||||
|
||||
if (op instanceof TensorArray) {
|
||||
//Create a TensorArray
|
||||
VarId vid = newVarId(op.outputVariable().getVarName(), outputFrameIter);
|
||||
VarId vid = outputFrameIter.toVarId(op.outputVariable().getVarName());
|
||||
Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid);
|
||||
tensorArrays.put(vid, new ArrayList<INDArray>());
|
||||
|
||||
// Note that TensorArray has 2 outputs - a 'dummy' SDVariable that represents it, and a second output (return a scalar 0.0)
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)};
|
||||
}
|
||||
INDArray dummy = mmgr.allocate(false, DataType.BOOL).assign(true);
|
||||
INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0);
|
||||
return new INDArray[]{dummy, scalar};
|
||||
} else if (op instanceof TensorArrayRead) {
|
||||
//Do lookup and return
|
||||
//Input 0 is the TensorArray (or dummy variable that represents it). Sometimes (for import) this can be like (TensorArray -> Enter -> TensorArrayRead)
|
||||
|
@ -278,7 +515,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
//Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead
|
||||
//TODO also TensorArrayWrite, scatter, etc??
|
||||
inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
|
||||
v = newVarId(inTensorArray.getVarName(), v.getParentFrame());
|
||||
v = v.getParentFrame().toVarId(inTensorArray.getVarName());
|
||||
}
|
||||
|
||||
List<INDArray> list = getTensorArrays().get(v);
|
||||
|
@ -289,7 +526,6 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
return new INDArray[]{out};
|
||||
} else if (op instanceof TensorArrayWrite) {
|
||||
//TensorArrayWrite - also has a scalar 0.0 that it returns...
|
||||
|
||||
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
|
||||
//Work out the varid (frame/iteration) of the tensor array:
|
||||
VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
|
||||
|
@ -303,7 +539,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
//Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite
|
||||
//TODO also TensorArrayScatter, etc??
|
||||
inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
|
||||
tArr = newVarId(inTensorArray.getVarName(), tArr.getParentFrame());
|
||||
tArr = tArr.getParentFrame().toVarId(inTensorArray.getVarName());
|
||||
}
|
||||
|
||||
//Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorArrayRead
|
||||
|
@ -330,11 +566,13 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
}
|
||||
l.set(idx, arr);
|
||||
|
||||
//Add a dependency
|
||||
Dep d = new ExecDoneDep();
|
||||
arrayUseTracker.addDependency(arr, d);
|
||||
|
||||
//Return dummy array
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
return new INDArray[]{Nd4j.scalar(0.0f)};
|
||||
}
|
||||
INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0);
|
||||
return new INDArray[]{scalar};
|
||||
} else if (op instanceof TensorArraySize) {
|
||||
//Index 0 is the TensorArray (or dummy variable that represents it)
|
||||
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
|
||||
|
@ -345,10 +583,9 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
}
|
||||
List<INDArray> l = tensorArrays.get(tArr);
|
||||
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
return new INDArray[]{Nd4j.scalar(DataType.INT, l.size())};
|
||||
}
|
||||
|
||||
INDArray scalar = mmgr.allocate(false, DataType.INT).assign(l.size());
|
||||
return new INDArray[]{scalar};
|
||||
} else if (op instanceof TensorArrayConcat) {
|
||||
SDVariable inTensorArray = op.arg(0); //Dummy variable representing the tensor array
|
||||
VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
|
||||
|
@ -356,12 +593,13 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
|
||||
}
|
||||
List<INDArray> l = tensorArrays.get(tArr);
|
||||
//TODO - empty checks. But is size 0 OK?
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
INDArray concat = Nd4j.concat(0, l.toArray(new INDArray[l.size()]));
|
||||
return new INDArray[]{concat};
|
||||
}
|
||||
|
||||
Concat c = new Concat(0, l.toArray(new INDArray[0]));
|
||||
List<LongShapeDescriptor> shape = c.calculateOutputShape();
|
||||
INDArray out = mmgr.allocate(false, shape.get(0));
|
||||
c.setOutputArgument(0, out);
|
||||
Nd4j.exec(c);
|
||||
return new INDArray[]{out};
|
||||
} else if (op instanceof TensorArrayGather) {
|
||||
//Input 0: the TensorArray
|
||||
//Input 1: the indices (1d integer vector)
|
||||
|
@ -383,7 +621,7 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
int[] idxArrInt = idxArr.toIntVector();
|
||||
|
||||
//Edge case: -1 means "all"
|
||||
ArrayList<INDArray> newList = new ArrayList<>();
|
||||
List<INDArray> newList = new ArrayList<>();
|
||||
if (idxArrInt.length == 1 && idxArrInt[0] == -1) {
|
||||
newList.addAll(l);
|
||||
} else {
|
||||
|
@ -392,11 +630,13 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
newList.add(l.get(id));
|
||||
}
|
||||
}
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
INDArray out = Nd4j.pile(newList);
|
||||
|
||||
Stack s = new Stack(newList.toArray(new INDArray[0]), null, 0);
|
||||
List<LongShapeDescriptor> shape = s.calculateOutputShape();
|
||||
INDArray out = mmgr.allocate(false, shape.get(0));
|
||||
s.setOutputArgument(0, out);
|
||||
Nd4j.exec(s);
|
||||
return new INDArray[]{out};
|
||||
}
|
||||
} else if (op instanceof TensorArrayScatter) {
|
||||
//Scatter values from a rank (N+1)d tensor into specific indices of the TensorArray
|
||||
//Input 0: the TensorArray
|
||||
|
@ -435,23 +675,20 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
|
||||
for (int i = 0; i < idxs.length; i++) {
|
||||
idx[0] = NDArrayIndex.point(i);
|
||||
INDArray get = valuesArr.get(idx).dup();
|
||||
INDArray get = mmgr.dup(valuesArr.get(idx));
|
||||
int outIdx = idxs[i];
|
||||
if(valuesArr.rank() == 2 && get.rank() == 2){
|
||||
//Workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7092
|
||||
get = get.reshape(get.length());
|
||||
}
|
||||
if (valuesArr.rank() == 1 && get.rank() > 0) {
|
||||
get = get.reshape(new long[0]);
|
||||
get = get.reshape();
|
||||
}
|
||||
l.set(outIdx, get);
|
||||
|
||||
//Add dependency for values array until end of execution
|
||||
arrayUseTracker.addDependency(get, new ExecDoneDep());
|
||||
}
|
||||
|
||||
//Return dummy array
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
return new INDArray[]{Nd4j.scalar(0.0f)};
|
||||
}
|
||||
INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0);
|
||||
return new INDArray[]{scalar};
|
||||
} else if (op instanceof TensorArraySplit) {
|
||||
//Split values from a rank (N+1)d tensor into sequential indices of the TensorArray
|
||||
//For example, orig = [8,2] array with split sizes (4,4) means TensorArray[0] = orig[0:4,:] and TensorArray[1] = orig[4:8,:]
|
||||
|
@ -486,33 +723,23 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
int soFar = 0;
|
||||
for (int i = 0; i < sizes.length; i++) {
|
||||
idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]);
|
||||
INDArray sub = splitArr.get(idx).dup();
|
||||
INDArray sub = mmgr.dup(splitArr.get(idx));
|
||||
l.set(i, sub);
|
||||
soFar += sizes[i];
|
||||
|
||||
//Add dependency for values array until end of execution
|
||||
arrayUseTracker.addDependency(sub, new ExecDoneDep());
|
||||
}
|
||||
|
||||
//Return dummy array
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
return new INDArray[]{Nd4j.scalar(0.0f)};
|
||||
}
|
||||
INDArray scalar = mmgr.allocate(false, DataType.FLOAT).assign(0.0);
|
||||
return new INDArray[]{scalar};
|
||||
} else {
|
||||
throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
|
||||
}
|
||||
} else if(op instanceof GradientBackwardsMarker){
|
||||
return new INDArray[]{Nd4j.scalar(1.0f)};
|
||||
} else if(op instanceof CustomOp){
|
||||
CustomOp c = (CustomOp)op;
|
||||
Nd4j.getExecutioner().exec(c);
|
||||
return c.outputArguments();
|
||||
} else if(op instanceof Op) {
|
||||
Op o = (Op) op;
|
||||
Nd4j.getExecutioner().exec(o);
|
||||
return new INDArray[]{o.z()};
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray getConstantOrVariable(String variableName) {
|
||||
SDVariable v = sameDiff.getVariable(variableName);
|
||||
|
@ -522,21 +749,19 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
}
|
||||
|
||||
@Override
|
||||
public DifferentialFunction getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues) {
|
||||
public SameDiffOp getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables) {
|
||||
SameDiffOp sdo = sameDiff.getOps().get(opName);
|
||||
DifferentialFunction df = sdo.getOp();
|
||||
|
||||
DifferentialFunction df = sameDiff.getOpById(opName);
|
||||
//TODO Switch to OpContext - and make sure executing like that is thread safe (i.e., array fields in ops are not used etc)
|
||||
|
||||
//TODO We should clone these ops - probably - as we don't want them shared between threads/sessions!
|
||||
//But let's only clone them *once* and cache in inference session - not on every exec
|
||||
|
||||
Preconditions.checkNotNull(df, "No differential function fond with name %s", opName);
|
||||
Preconditions.checkNotNull(df, "No differential function found with name \"%s\"", opName);
|
||||
|
||||
if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration ||
|
||||
df instanceof Merge || df instanceof Switch || df instanceof If || df instanceof While ||
|
||||
df instanceof BaseTensorOp){
|
||||
df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) {
|
||||
//Control dependencies and tensor ops (like TensorArray, TensorArrayRead etc) don't need inputs set, execution is a special case
|
||||
return df;
|
||||
return sdo;
|
||||
}
|
||||
|
||||
//Infer the args based on the inputs (variable + frame + iteration)
|
||||
|
@ -546,72 +771,16 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
int numNonConstInsAllIters = (allIterInputs == null ? 0 : allIterInputs.size());
|
||||
int numConstPhIns = (constAndPhInputs == null ? 0 : constAndPhInputs.size());
|
||||
|
||||
Set<String> constEnterInputs = null;
|
||||
if (numArgs != (numNonConstIns + numConstPhIns + numNonConstInsAllIters)) {
|
||||
boolean anyConstEnterInputs = false;
|
||||
SDVariable[] args = df.args();
|
||||
for(SDVariable v : args){
|
||||
Variable var = sameDiff.getVariables().get(v.getVarName());
|
||||
//Nested enter case:
|
||||
DifferentialFunction inputVarFn = (var.getOutputOfOp() == null ? null : sameDiff.getOps().get(var.getOutputOfOp()).getOp());
|
||||
if(inputVarFn instanceof Enter && ((Enter)inputVarFn).isConstant()){
|
||||
anyConstEnterInputs = true;
|
||||
if(constEnterInputs == null)
|
||||
constEnterInputs = new HashSet<>();
|
||||
constEnterInputs.add(v.getVarName());
|
||||
}
|
||||
}
|
||||
|
||||
int constEnterInputCount = 0;
|
||||
if(anyConstEnterInputs){
|
||||
/*
|
||||
2019/01/26: AB
|
||||
Resolve nested enter inputs (constants 2+ enters in)
|
||||
Why this hack is necessary: consider the following (sub) graph: constX -> Enter(a) -> Enter(b) -> opY
|
||||
On iterations (a=0, b=0) all is well, opY gets triggered as normal.
|
||||
On iterations (a>0, b=*) the "opY is available for exec" won't be triggered.
|
||||
This is because Enter(a) is only executed once, on iteration 0 of the outer loop.
|
||||
Consequently, Enter(b) is not triggered as available on iteration 1+.
|
||||
When we do the lookup for the actual array to use for op execution (i.e., get inputs for opY(a=1,b=0))
|
||||
it won't be found.
|
||||
This is a bit of an ugly hack, though I've yet to find a cleaner solution.
|
||||
It should only be required with the combination of: constants, 2 levels of enters, and more than 1 iteration in each loop.
|
||||
*/
|
||||
|
||||
//For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should
|
||||
// be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0))
|
||||
for(String s : constEnterInputs){
|
||||
//First: check if this has already been provided
|
||||
if(constAndPhInputs != null && constAndPhInputs.contains(s)){
|
||||
//already resolved/provided
|
||||
continue;
|
||||
}
|
||||
boolean found = false;
|
||||
if(allIterInputs != null) {
|
||||
for (VarId vid : allIterInputs) {
|
||||
if (s.equals(vid.getVariable())) {
|
||||
//Already resolved/provided
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(found)
|
||||
continue;
|
||||
|
||||
constEnterInputCount++;
|
||||
}
|
||||
}
|
||||
|
||||
if (numArgs > 1) {
|
||||
//Might be due to repeated inputs
|
||||
Set<String> uniqueArgNames = new HashSet<>();
|
||||
Collections.addAll(uniqueArgNames, argNames);
|
||||
Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters + constEnterInputCount),
|
||||
Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters),
|
||||
"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(),
|
||||
opName, uniqueArgNames, opInputs, constAndPhInputs);
|
||||
} else {
|
||||
Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns + constEnterInputCount),
|
||||
Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns),
|
||||
"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(),
|
||||
opName, argNames, opInputs, constAndPhInputs);
|
||||
}
|
||||
|
@ -625,44 +794,18 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
SDVariable v = sameDiff.getVariable(s);
|
||||
if (v.isConstant()) {
|
||||
args[i] = v.getArr();
|
||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||
args[i] = v.getArr();
|
||||
} else if (v.isPlaceHolder()) {
|
||||
Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s);
|
||||
args[i] = placeholderValues.get(s);
|
||||
} else if(constEnterInputs != null && constEnterInputs.contains(s)){
|
||||
//For enter nodes that are constants, we want iteration 0 in all frames in the hierarchy
|
||||
//For example, const -> Enter(a) -> Enter(b) -> op; in this case, the input to Op (at any frame/iteration) should
|
||||
// be the constant value - which is recorded as (frame="a",iter=0,parent=(frame="b",iter=0))
|
||||
VarId vid = newVarId(s, frameIter.clone());
|
||||
vid.setIteration(0);
|
||||
FrameIter toZero = vid.getParentFrame();
|
||||
while(toZero != null){
|
||||
toZero.setIteration(0);
|
||||
toZero = toZero.getParentFrame();
|
||||
}
|
||||
INDArray arr = this.nodeOutputs.get(vid);
|
||||
args[i] = arr;
|
||||
} else {
|
||||
if(opInputs != null) {
|
||||
for (VarId vid : opInputs) {
|
||||
if (vid.getVariable().equals(s)) {
|
||||
args[i] = this.nodeOutputs.get(vid);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(args[i] == null && allIterInputs != null){
|
||||
for(VarId vid : allIterInputs){
|
||||
if(vid.getVariable().equals(s)){
|
||||
args[i] = this.nodeOutputs.get(vid);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
VarId vid = lookup(s, opInputs, allIterInputs, true);
|
||||
args[i] = nodeOutputs.get(vid);
|
||||
}
|
||||
Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName());
|
||||
i++;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//Set the op inputs and output arguments
|
||||
|
@ -677,7 +820,12 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
customOp.setInputArguments(args);
|
||||
}
|
||||
|
||||
df.resolvePropertiesFromSameDiffBeforeExecution();
|
||||
if (df instanceof Identity) {
|
||||
//We don't need to allocate an output array for Identity, we pass through the input array without copying
|
||||
return sdo;
|
||||
}
|
||||
|
||||
df.resolvePropertiesFromSameDiffBeforeExecution(); //TODO This is to be removed
|
||||
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
|
||||
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
|
||||
String[] outNames = df.outputVariablesNames();
|
||||
|
@ -695,12 +843,9 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
reqShape = reqShape.asDataType(dt);
|
||||
}
|
||||
|
||||
if(currOutput == null || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop){
|
||||
INDArray out;
|
||||
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
out = Nd4j.create(reqShape, false);
|
||||
}
|
||||
if (currOutput == null || currOutput.wasClosed() || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop) {
|
||||
boolean isOutput = allReqVariables.contains(outNames[i]);
|
||||
INDArray out = mmgr.allocate(isOutput, reqShape);
|
||||
customOp.setOutputArgument(i, out);
|
||||
}
|
||||
}
|
||||
|
@ -753,30 +898,30 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
INDArray z = op.z();
|
||||
if (z == null || !op.x().equalShapes(z) || isLoop) {
|
||||
//Note: edge case: [x,y].sum(empty) = [x,y] for TF import compatibility.
|
||||
op.setZ(op.x().ulike());
|
||||
z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
|
||||
op.setZ(z);
|
||||
}
|
||||
} else {
|
||||
List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
|
||||
Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
|
||||
INDArray z = op.z();
|
||||
if (z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
|
||||
if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
|
||||
if (log.isTraceEnabled()) {
|
||||
log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}",
|
||||
op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString());
|
||||
}
|
||||
|
||||
LongShapeDescriptor lsd = outputShape.get(0);
|
||||
try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
|
||||
//TODO Proper workspace support will be added to SameDiff later
|
||||
z = Nd4j.create(lsd, false);
|
||||
}
|
||||
|
||||
boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
|
||||
z = mmgr.allocate(isOutput, lsd);
|
||||
op.setZ(z);
|
||||
}
|
||||
}
|
||||
df.resolvePropertiesFromSameDiffBeforeExecution();
|
||||
}
|
||||
|
||||
return df;
|
||||
return sdo;
|
||||
}
|
||||
|
||||
|
||||
|
@ -785,15 +930,69 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) {
|
||||
return getConstantOrVariable(n);
|
||||
} else {
|
||||
VarId inVarId = null;
|
||||
if(opInputs != null){
|
||||
inVarId = lookup(n, opInputs, false);
|
||||
}
|
||||
if(inVarId == null && allIterInputs != null && !allIterInputs.isEmpty()){
|
||||
inVarId = lookup(n, allIterInputs, false);
|
||||
}
|
||||
VarId inVarId = lookup(n, opInputs, allIterInputs, false);
|
||||
Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.getVarName());
|
||||
return nodeOutputs.get(inVarId);
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public abstract static class Dep {
|
||||
protected String frame;
|
||||
protected FrameIter parentFrame;
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public static class OpDep extends Dep {
|
||||
protected String opName;
|
||||
protected int iter;
|
||||
|
||||
protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) {
|
||||
this.opName = opName;
|
||||
this.frame = frame;
|
||||
this.iter = iter;
|
||||
this.parentFrame = parentFrame;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OpDep(" + opName + ",frame=" + frame + ",iter=" + iter + (parentFrame == null ? "" : ",parent=" + parentFrame) + ")";
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@AllArgsConstructor
|
||||
protected static class PlaceholderDep extends Dep {
|
||||
protected String phName;
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@AllArgsConstructor
|
||||
protected static class VariableDep extends Dep {
|
||||
protected String varName;
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@AllArgsConstructor
|
||||
protected static class ConstantDep extends Dep {
|
||||
protected String constName;
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@AllArgsConstructor
|
||||
protected static class ReqOutputDep extends Dep {
|
||||
protected String outputName;
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@NoArgsConstructor
|
||||
protected static class ExecDoneDep extends Dep {
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,4 +34,6 @@ public class SameDiffOp {
|
|||
protected List<String> inputsToOp; //Name of SDVariables as input
|
||||
protected List<String> outputsOfOp; //Name of SDVariables as output
|
||||
protected List<String> controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec)
|
||||
protected List<String> varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op
|
||||
protected List<String> controlDepFor; //Name of the variables that this op is a control dependency for
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
||||
import java.io.Closeable;
|
||||
|
||||
/**
|
||||
* SessionMemMgr - aka "Session Memory Manager" is responsible for allocating, managing, and deallocating memory used
|
||||
* during SameDiff execution.<br>
|
||||
* This interface allows different memory management strategies to be used, abstracted away from the actual graph
|
||||
* execution logic
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public interface SessionMemMgr extends Closeable {
|
||||
|
||||
/**
|
||||
* Allocate an array with the specified datatype and shape.<br>
|
||||
* NOTE: This array should be assumed to be uninitialized - i.e., contains random values.
|
||||
*
|
||||
* @param detached If true: the array is safe to return outside of the SameDiff session run (for example, the array
|
||||
* is one that may be returned to the user)
|
||||
* @param dataType Datatype of the returned array
|
||||
* @param shape Array shape
|
||||
* @return The newly allocated (uninitialized) array
|
||||
*/
|
||||
INDArray allocate(boolean detached, DataType dataType, long... shape);
|
||||
|
||||
/**
|
||||
* As per {@link #allocate(boolean, DataType, long...)} but from a LongShapeDescriptor instead
|
||||
*/
|
||||
INDArray allocate(boolean detached, LongShapeDescriptor descriptor);
|
||||
|
||||
/**
|
||||
* Allocate an uninitialized array with the same datatype and shape as the specified array
|
||||
*/
|
||||
INDArray ulike(INDArray arr);
|
||||
|
||||
/**
|
||||
* Duplicate the specified array, to an array that is managed/allocated by the session memory manager
|
||||
*/
|
||||
INDArray dup(INDArray arr);
|
||||
|
||||
/**
|
||||
* Release the array. All arrays allocated via one of the allocate methods should be returned here once they are no
|
||||
* longer used, and all references to them should be cleared.
|
||||
* After calling release, anything could occur to the array - deallocated, workspace closed, reused, etc.
|
||||
*
|
||||
* @param array The array that can be released
|
||||
*/
|
||||
void release(INDArray array);
|
||||
|
||||
/**
|
||||
* Close the session memory manager and clean up any memory / resources, if any
|
||||
*/
|
||||
void close();
|
||||
|
||||
}
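To make the allocate/release contract above concrete, here is a minimal lifecycle sketch. It is not part of this change set: it assumes the ArrayCloseMemoryMgr implementation added later in this PR, and the class/variable names are illustrative only.

import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

public class MemMgrLifecycleSketch {
    public static void main(String[] args) {
        SessionMemMgr mmgr = new ArrayCloseMemoryMgr();

        //Scratch array: detached=false, so it must be released before the session run completes
        INDArray tmp = mmgr.allocate(false, DataType.FLOAT, 2, 3);
        tmp.assign(1.0);                                    //allocate() returns uninitialized memory

        //Detached array: per the contract above, safe to return outside the session run
        INDArray output = mmgr.allocate(true, tmp.dataType(), tmp.shape());
        output.assign(tmp);

        mmgr.release(tmp);                                  //tmp may be closed or reused after this point
        mmgr.close();
        System.out.println(output);
    }
}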
|
|
@ -0,0 +1,232 @@
|
|||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.learning.GradientUpdater;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
import org.nd4j.linalg.primitives.AtomicDouble;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* TrainingSession extends InferenceSession, to add training-specific functionality:<br>
|
||||
* - Application of regularization (L1, L2, weight decay etc)<br>
|
||||
* - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)<br>
|
||||
* - Calculation of regularization scores (Score for L1, L2, etc)
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class TrainingSession extends InferenceSession {
|
||||
|
||||
protected TrainingConfig config;
|
||||
protected Map<String, String> gradVarToVarMap;
|
||||
protected Map<String, GradientUpdater> updaters;
|
||||
protected Map<String, Integer> lossVarsToLossIdx;
|
||||
protected double[] currIterLoss;
|
||||
protected Map<Class<?>, AtomicDouble> currIterRegLoss;
|
||||
protected List<Listener> listeners;
|
||||
|
||||
|
||||
public TrainingSession(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform one iteration of training - i.e., do forward and backward passes, and update the parameters
|
||||
*
|
||||
* @param config Training configuration
|
||||
* @param placeholders Current placeholders
|
||||
* @param paramsToTrain Set of parameters that will be trained
|
||||
* @param updaters Current updater state
|
||||
* @param batch Current data/batch (mainly for listeners, should have already been converted to placeholders map)
|
||||
* @param lossVariables Loss variables (names)
|
||||
* @param listeners Listeners (if any)
|
||||
* @param at Current epoch, iteration, etc
|
||||
* @return The Loss at the current iteration
|
||||
*/
|
||||
public Loss trainingIteration(TrainingConfig config, Map<String, INDArray> placeholders, Set<String> paramsToTrain, Map<String, GradientUpdater> updaters,
|
||||
MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at) {
|
||||
this.config = config;
|
||||
this.updaters = updaters;
|
||||
|
||||
//Preprocess listeners, get the relevant ones
|
||||
if (listeners == null) {
|
||||
this.listeners = null;
|
||||
} else {
|
||||
List<Listener> filtered = new ArrayList<>();
|
||||
for (Listener l : listeners) {
|
||||
if (l.isActive(at.operation())) {
|
||||
filtered.add(l);
|
||||
}
|
||||
}
|
||||
this.listeners = filtered.isEmpty() ? null : filtered;
|
||||
}
|
||||
|
||||
List<String> requiredActivations = new ArrayList<>();
|
||||
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
|
||||
for (String s : paramsToTrain) {
|
||||
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
|
||||
SDVariable v = sameDiff.getVariable(s);
|
||||
Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s",
|
||||
s, v.getVariableType());
|
||||
SDVariable grad = sameDiff.getVariable(s).getGradient();
|
||||
if (grad == null) {
|
||||
//In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it
|
||||
//For example: floatVar -> cast to integer -> cast to float -> sum -> loss
|
||||
//In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss)
|
||||
continue;
|
||||
}
|
||||
|
||||
requiredActivations.add(grad.getVarName());
|
||||
|
||||
gradVarToVarMap.put(grad.getVarName(), s);
|
||||
}
|
||||
|
||||
//Set up losses
|
||||
lossVarsToLossIdx = new LinkedHashMap<>();
|
||||
List<String> lossVars;
|
||||
currIterLoss = new double[lossVariables.size()];
|
||||
currIterRegLoss = new HashMap<>();
|
||||
for (int i = 0; i < lossVariables.size(); i++) {
|
||||
lossVarsToLossIdx.put(lossVariables.get(i), i);
|
||||
}
|
||||
|
||||
//Do training iteration
|
||||
List<String> outputVars = new ArrayList<>(gradVarToVarMap.keySet()); //TODO this should be empty, and grads calculated in requiredActivations
|
||||
Map<String, INDArray> m = output(outputVars, placeholders, batch, requiredActivations, listeners, at);
|
||||
|
||||
|
||||
double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()];
|
||||
System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length);
|
||||
if (currIterRegLoss.size() > 0) {
|
||||
lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size());
|
||||
lossVars.addAll(lossVariables);
|
||||
int s = currIterLoss.length;
|
||||
//Collect regularization losses
|
||||
for (Map.Entry<Class<?>, AtomicDouble> entry : currIterRegLoss.entrySet()) {
|
||||
lossVars.add(entry.getKey().getSimpleName());
|
||||
finalLoss[s++] = entry.getValue().get();
|
||||
}
|
||||
} else {
|
||||
lossVars = lossVariables;
|
||||
}
|
||||
|
||||
Loss loss = new Loss(lossVars, finalLoss);
|
||||
if (listeners != null) {
|
||||
for (Listener l : listeners) {
|
||||
if (l.isActive(Operation.TRAINING)) {
|
||||
l.iterationDone(sameDiff, at, batch, loss);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return loss;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
|
||||
//Get outputs from InferenceSession
|
||||
INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
|
||||
|
||||
List<String> outputs = op.getOutputsOfOp();
|
||||
int outIdx = 0;
|
||||
for (String s : outputs) {
|
||||
//If this is a loss variable - record it
|
||||
if (lossVarsToLossIdx.containsKey(s)) {
|
||||
int lossIdx = lossVarsToLossIdx.get(s);
|
||||
INDArray arr = out[outIdx];
|
||||
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
|
||||
currIterLoss[lossIdx] += l;
|
||||
}
|
||||
|
||||
//If this is a gradient variable - apply the updater and update the parameter array in-line
|
||||
if (gradVarToVarMap.containsKey(s)) {
|
||||
String varName = gradVarToVarMap.get(s);
|
||||
//log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s);
|
||||
|
||||
Variable gradVar = sameDiff.getVariables().get(s);
|
||||
if (gradVar.getInputsForOp() != null && !gradVar.getInputsForOp().isEmpty()) {
|
||||
//Should be rare, and we should handle this by tracking dependencies, and only update when safe
|
||||
// (i.e., dependency tracking)
|
||||
throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName);
|
||||
}
|
||||
|
||||
GradientUpdater u = updaters.get(varName);
|
||||
Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName);
|
||||
|
||||
Variable var = sameDiff.getVariables().get(varName);
|
||||
INDArray gradArr = out[outIdx];
|
||||
INDArray paramArr = var.getVariable().getArr();
|
||||
|
||||
//Pre-updater regularization (L1, L2)
|
||||
List<Regularization> r = config.getRegularization();
|
||||
if (r != null && r.size() > 0) {
|
||||
double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
|
||||
for (Regularization reg : r) {
|
||||
if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
|
||||
if (this.listeners != null) {
|
||||
double score = reg.score(paramArr, at.iteration(), at.epoch());
|
||||
if (!currIterRegLoss.containsKey(reg.getClass())) {
|
||||
currIterRegLoss.put(reg.getClass(), new AtomicDouble());
|
||||
}
|
||||
currIterRegLoss.get(reg.getClass()).addAndGet(score);
|
||||
}
|
||||
reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
u.applyUpdater(gradArr, at.iteration(), at.epoch());
|
||||
|
||||
//Post-apply regularization (weight decay)
|
||||
if (r != null && r.size() > 0) {
|
||||
double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
|
||||
for (Regularization reg : r) {
|
||||
if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
|
||||
if (this.listeners != null) {
|
||||
double score = reg.score(paramArr, at.iteration(), at.epoch());
|
||||
if (!currIterRegLoss.containsKey(reg.getClass())) {
|
||||
currIterRegLoss.put(reg.getClass(), new AtomicDouble());
|
||||
}
|
||||
currIterRegLoss.get(reg.getClass()).addAndGet(score);
|
||||
}
|
||||
reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (listeners != null) {
|
||||
for (Listener l : listeners) {
|
||||
if (l.isActive(at.operation()))
|
||||
l.preUpdate(sameDiff, at, var, gradArr);
|
||||
}
|
||||
}
|
||||
|
||||
//Update:
|
||||
if (config.isMinimize()) {
|
||||
paramArr.subi(gradArr);
|
||||
} else {
|
||||
paramArr.addi(gradArr);
|
||||
}
|
||||
log.trace("Applied updater to gradient and updated variable: {}", varName);
|
||||
}
|
||||
|
||||
outIdx++;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
}
|
|
@ -35,8 +35,7 @@ public class Variable {
|
|||
protected List<String> controlDepsForOp; //if a op control dependency (x -> opY) exists, then "opY" will be in this list
|
||||
protected List<String> controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list
|
||||
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||
protected List<String> controlDeps; //Control dependencies: name of variables that must be available before this variable is considered available for execution
|
||||
protected int outputOfOpIdx; //Index of the output for the op (say, variable is output number 2 of op "outputOfOp")
|
||||
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
||||
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
||||
protected int variableIndex = -1;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package org.nd4j.autodiff.samediff.internal.memory;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* Abstract memory manager, that implements ulike and dup methods using the underlying allocate methods
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public abstract class AbstractMemoryMgr implements SessionMemMgr {
|
||||
|
||||
@Override
|
||||
public INDArray ulike(@NonNull INDArray arr) {
|
||||
return allocate(false, arr.dataType(), arr.shape());
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray dup(@NonNull INDArray arr) {
|
||||
INDArray out = ulike(arr);
|
||||
out.assign(arr);
|
||||
return out;
|
||||
}
|
||||
}
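AbstractMemoryMgr only supplies ulike and dup, so a concrete subclass still has to implement both allocate overloads plus release and close. The sketch below (not part of this change set) shows that pattern with an illustrative allocation counter; it otherwise behaves like the ArrayCloseMemoryMgr that follows.

import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

public class CountingMemoryMgr extends AbstractMemoryMgr {

    private long allocations = 0;   //Simple statistic, purely for illustration

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        allocations++;
        return Nd4j.createUninitialized(dataType, shape);
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        allocations++;
        return Nd4j.create(descriptor, false);
    }

    @Override
    public void release(@NonNull INDArray array) {
        if (!array.wasClosed() && array.closeable())
            array.close();          //Deallocate immediately, as ArrayCloseMemoryMgr does
    }

    @Override
    public void close() {
        //No resources to clean up
    }

    public long getAllocationCount() {
        return allocations;
    }
}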
|
|
@ -0,0 +1,43 @@
|
|||
package org.nd4j.autodiff.samediff.internal.memory;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* A simple memory management strategy that deallocates memory as soon as it is no longer needed.<br>
|
||||
* This should result in a minimal amount of memory, but will have some overhead - notably, the cost of deallocating
|
||||
* and reallocating memory all the time.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class ArrayCloseMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
|
||||
return Nd4j.createUninitialized(dataType, shape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
|
||||
return Nd4j.create(descriptor, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(@NonNull INDArray array) {
|
||||
if (!array.wasClosed() && array.closeable()) {
|
||||
array.close();
|
||||
log.trace("Closed array (deallocated) - id={}", array.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
//No-op
|
||||
}
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
package org.nd4j.autodiff.samediff.internal.memory;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.DependencyList;
|
||||
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
|
||||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:<br>
|
||||
* - All arrays that are supposed to be closed, have been closed<br>
|
||||
* - Arrays are only passed to the close method exactly once (unless they are requested outputs)<br>
|
||||
* - Arrays that are passed to the close method were originally allocated by the session memory manager<br>
|
||||
* <br>
|
||||
* How to use:<br>
|
||||
* 1. Perform an inference or training iteration, as normal<br>
|
||||
* 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays<br>
|
||||
* <p>
|
||||
* NOTE: This is intended for debugging and testing only
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Slf4j
|
||||
public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
|
||||
|
||||
private final SameDiff sd;
|
||||
private final SessionMemMgr underlying;
|
||||
private final Map<INDArray, Boolean> released = new IdentityHashMap<>();
|
||||
|
||||
public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) {
|
||||
this.sd = sd;
|
||||
this.underlying = underlying;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
|
||||
INDArray out = underlying.allocate(detached, dataType, shape);
|
||||
released.put(out, false);
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
|
||||
INDArray out = underlying.allocate(detached, descriptor);
|
||||
released.put(out, false);
|
||||
return out;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(INDArray array) {
|
||||
Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" +
|
||||
" this memory manager: id=%s", array.getId());
|
||||
if (released.get(array)) {
|
||||
//Already released
|
||||
InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
|
||||
IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
|
||||
DependencyList<INDArray, InferenceSession.Dep> dl = arrayUseTracker.getDependencies(array);
|
||||
System.out.println(dl);
|
||||
if (dl.getDependencies() != null) {
|
||||
for (InferenceSession.Dep d : dl.getDependencies()) {
|
||||
System.out.println(d + ": " + arrayUseTracker.isSatisfied(d));
|
||||
}
|
||||
}
|
||||
if (dl.getOrDependencies() != null) {
|
||||
for (Pair<InferenceSession.Dep, InferenceSession.Dep> p : dl.getOrDependencies()) {
|
||||
System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" +
|
||||
" an earlier release call to this memory manager: id=%s", array.getId());
|
||||
log.trace("Released array: id = {}", array.getId());
|
||||
released.put(array, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
underlying.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check that all arrays have been released (after an inference call) except for the specified arrays.
|
||||
*
|
||||
* @param except Arrays that should not have been closed (usually network outputs)
|
||||
*/
|
||||
public void assertAllReleasedExcept(@NonNull Collection<INDArray> except) {
|
||||
Set<INDArray> allVarPhConst = null;
|
||||
|
||||
for (INDArray arr : except) {
|
||||
if (!released.containsKey(arr)) {
|
||||
//Check if constant, variable or placeholder - the user may have requested it as an output
|
||||
if (allVarPhConst == null)
|
||||
allVarPhConst = identitySetAllConstPhVar();
|
||||
if (allVarPhConst.contains(arr))
|
||||
continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager
|
||||
|
||||
throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager");
|
||||
}
|
||||
|
||||
boolean released = this.released.get(arr);
|
||||
if (released) {
|
||||
throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was");
|
||||
}
|
||||
}
|
||||
|
||||
Set<INDArray> exceptSet = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
|
||||
exceptSet.addAll(except);
|
||||
|
||||
int numNotClosed = 0;
|
||||
Set<INDArray> notReleased = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
|
||||
InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
|
||||
IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
|
||||
for (Map.Entry<INDArray, Boolean> e : released.entrySet()) {
|
||||
INDArray a = e.getKey();
|
||||
if (!exceptSet.contains(a)) {
|
||||
boolean b = e.getValue();
|
||||
if (!b) {
|
||||
notReleased.add(a);
|
||||
numNotClosed++;
|
||||
log.info("Not released: array id {}", a.getId());
|
||||
DependencyList<INDArray, InferenceSession.Dep> list = arrayUseTracker.getDependencies(a);
|
||||
List<InferenceSession.Dep> l = list.getDependencies();
|
||||
List<Pair<InferenceSession.Dep, InferenceSession.Dep>> l2 = list.getOrDependencies();
|
||||
if (l != null) {
|
||||
for (InferenceSession.Dep d : l) {
|
||||
if (!arrayUseTracker.isSatisfied(d)) {
|
||||
log.info(" Not satisfied: {}", d);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (l2 != null) {
|
||||
for (Pair<InferenceSession.Dep, InferenceSession.Dep> d : l2) {
|
||||
if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) {
|
||||
log.info(" Not satisfied: {}", d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (numNotClosed > 0) {
|
||||
System.out.println(sd.summary());
|
||||
throw new IllegalStateException(numNotClosed + " arrays were not released but should have been");
|
||||
}
|
||||
}
|
||||
|
||||
protected Set<INDArray> identitySetAllConstPhVar() {
|
||||
Set<INDArray> set = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
|
||||
for (SDVariable v : sd.variables()) {
|
||||
if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) {
|
||||
set.add(v.getArr());
|
||||
}
|
||||
}
|
||||
return set;
|
||||
}
|
||||
}
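A small sketch of the validation behaviour (not part of this change set). The full flow from the javadoc - run inference, then call assertAllReleasedExcept with the outputs - requires this wrapper to be installed as the session's memory manager, which is not shown in this diff; the sketch only exercises the release-tracking checks directly.

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr;
import org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CloseValidationSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();    //Only needed for the constructor in this sketch
        CloseValidationMemoryMgr mmgr = new CloseValidationMemoryMgr(sd, new ArrayCloseMemoryMgr());

        INDArray a = mmgr.allocate(false, DataType.FLOAT, 2, 2);
        mmgr.release(a);                    //OK: first (and only) release of a tracked array

        INDArray foreign = Nd4j.create(DataType.FLOAT, 2, 2);
        mmgr.release(foreign);              //Throws IllegalStateException: not allocated by this manager
    }
}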
|
|
@ -0,0 +1,42 @@
|
|||
package org.nd4j.autodiff.samediff.internal.memory;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* A simple "no-op" memory manager that relies on JVM garbage collector for memory management.
|
||||
* Assuming other references have been cleared (they should have been), the arrays will be cleaned up by the
|
||||
* garbage collector at some point.
|
||||
*
|
||||
* This memory management strategy is not recommended for performance or memory reasons, and should only be used
|
||||
* for testing and debugging purposes
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class NoOpMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
|
||||
return Nd4j.createUninitialized(dataType, shape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
|
||||
return Nd4j.create(descriptor, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release(@NonNull INDArray array) {
|
||||
//No-op, rely on GC to clear arrays
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
//No-op
|
||||
}
|
||||
|
||||
}
|
|
@ -90,10 +90,10 @@ public class SDNN extends SDOps {
|
|||
}
|
||||
|
||||
/**
|
||||
* @see #biasAdd(String, SDVariable, SDVariable)
|
||||
* @see #biasAdd(String, SDVariable, SDVariable, boolean)
|
||||
*/
|
||||
public SDVariable biasAdd(SDVariable input, SDVariable bias) {
|
||||
return biasAdd(null, input, bias);
|
||||
public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
|
||||
return biasAdd(null, input, bias, nchw);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -102,12 +102,14 @@ public class SDNN extends SDOps {
|
|||
* @param name Name of the output variable
|
||||
* @param input 4d input variable
|
||||
* @param bias 1d bias
|
||||
* @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels].
|
||||
* Unused for 2d inputs
|
||||
* @return Output variable
|
||||
*/
|
||||
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
|
||||
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) {
|
||||
validateFloatingPoint("biasAdd", "input", input);
|
||||
validateFloatingPoint("biasAdd", "bias", bias);
|
||||
SDVariable ret = f().biasAdd(input, bias);
|
||||
SDVariable ret = f().biasAdd(input, bias, nchw);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
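For reference, a minimal call of the updated signature (sketch only, not part of this change set; the variable names and shapes are arbitrary):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class BiasAddSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 8, 8);      //NCHW: [minibatch, channels, height, width]
        SDVariable b = sd.var("b", Nd4j.create(DataType.FLOAT, 3).assign(0.1)); //1d bias, one value per channel
        SDVariable out = sd.nn().biasAdd("withBias", in, b, true);              //nchw=true: bias added along dimension 1
        System.out.println(out.getVarName());
    }
}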
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.serde;
|
||||
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import java.nio.ByteOrder;
|
||||
|
@ -847,6 +848,28 @@ public class FlatBuffersMapper {
|
|||
}
|
||||
int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes);
|
||||
|
||||
//Control dependencies:
|
||||
SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName());
|
||||
|
||||
int opCds = 0;
|
||||
int[] opCdsArr = mapOrNull(sdo.getControlDeps(), bufferBuilder);
|
||||
if(opCdsArr != null){
|
||||
opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr);
|
||||
}
|
||||
|
||||
int varCds = 0;
|
||||
int[] varCdsArr = mapOrNull(sdo.getVarControlDeps(), bufferBuilder);
|
||||
if(varCdsArr != null){
|
||||
varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr);
|
||||
}
|
||||
|
||||
int cdsFor = 0;
|
||||
int[] cdsForArr = mapOrNull(sdo.getControlDepFor(), bufferBuilder);
|
||||
if(cdsForArr != null){
|
||||
cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr);
|
||||
}
|
||||
|
||||
|
||||
int flatNode = FlatNode.createFlatNode(
|
||||
bufferBuilder,
|
||||
ownId,
|
||||
|
@ -867,12 +890,26 @@ public class FlatBuffersMapper {
|
|||
outVarNamesOffset,
|
||||
opNameOffset,
|
||||
outTypesOffset, //Output types
|
||||
scalar
|
||||
scalar,
|
||||
opCds,
|
||||
varCds,
|
||||
cdsFor
|
||||
);
|
||||
|
||||
return flatNode;
|
||||
}
|
||||
|
||||
public static int[] mapOrNull(List<String> list, FlatBufferBuilder fbb){
|
||||
if(list == null)
|
||||
return null;
|
||||
int[] out = new int[list.size()];
|
||||
int i=0;
|
||||
for(String s : list){
|
||||
out[i++] = fbb.createString(s);
|
||||
}
|
||||
return out;
|
||||
}
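A minimal illustration of the new mapOrNull helper (not part of this change set; the op names are arbitrary, and the generated FlatNode class is assumed to be org.nd4j.graph.FlatNode as elsewhere in ND4J): it turns a list of strings into FlatBuffers string offsets, which can then be wrapped into one of the control-dependency vectors used above.

import com.google.flatbuffers.FlatBufferBuilder;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.graph.FlatNode;

import java.util.Arrays;

public class MapOrNullSketch {
    public static void main(String[] args) {
        FlatBufferBuilder fbb = new FlatBufferBuilder();

        int[] offsets = FlatBuffersMapper.mapOrNull(Arrays.asList("opA", "opB"), fbb); //Two string offsets
        int opCds = (offsets == null ? 0 : FlatNode.createControlDepsVector(fbb, offsets));

        int[] none = FlatBuffersMapper.mapOrNull(null, fbb);                           //null in -> null out
        System.out.println(opCds + " " + (none == null));
    }
}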
|
||||
|
||||
public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df ){
|
||||
Map<String,Integer> nameToIdxMap = new HashMap<>();
|
||||
int count = 0;
|
||||
|
|
|
@@ -131,12 +131,12 @@ public class GradCheckUtil {
             // in this case, gradients of x and y are all 0 too

         //Collect variables to get gradients for - we want placeholders AND variables
-        Set<String> gradVarNames = new HashSet<>();
+        Set<String> varsNeedingGrads = new HashSet<>();
         for(Variable v : sd.getVariables().values()){
             if(v.getVariable().dataType().isFPType() && (v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER)){
                 SDVariable g = v.getVariable().getGradient();
                 Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
-                gradVarNames.add(g.getVarName());
+                varsNeedingGrads.add(v.getName());
             }
         }

@@ -164,7 +164,7 @@ public class GradCheckUtil {
         }

-        sd.execBackwards(placeholderValues, new ArrayList<>(gradVarNames));
+        Map<String,INDArray> gm = sd.calculateGradients(placeholderValues, varsNeedingGrads);

         //Remove listener, to reduce overhead
         sd.getListeners().remove(listenerIdx);

@@ -183,11 +183,11 @@ public class GradCheckUtil {
             if(g == null){
                 throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
             }
-            INDArray ga = g.getArr();
+            INDArray ga = gm.get(v.getVarName());
             if(ga == null){
                 throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
             }
-            if(!Arrays.equals(v.getArr().shape(), g.getArr().shape())){
+            if(!Arrays.equals(v.getArr().shape(), ga.shape())){
                 throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" +
                         v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
                         Arrays.toString(ga.shape()));

@@ -408,18 +408,18 @@ public class GradCheckUtil {

         //Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations
         sd.createGradFunction();
-        Set<String> gradVarNames = new HashSet<>();
+        Set<String> varsRequiringGrads = new HashSet<>();
         for(String s : actGrads){
             SDVariable grad = sd.getVariable(s).gradient();
             Preconditions.checkState( grad != null,"Could not get gradient for activation \"%s\": gradient variable is null", s);
-            gradVarNames.add(grad.getVarName());
+            varsRequiringGrads.add(s);
         }

         //Calculate analytical gradients
-        sd.execBackwards(config.getPlaceholderValues(), new ArrayList<>(gradVarNames));
+        Map<String,INDArray> grads = sd.calculateGradients(config.getPlaceholderValues(), new ArrayList<>(varsRequiringGrads));
         Map<String,INDArray> gradientsForAct = new HashMap<>();
         for(String s : actGrads){
-            INDArray arr = sd.getVariable(s).gradient().getArr();
+            INDArray arr = grads.get(s);
             Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s);
             gradientsForAct.put(s, arr.dup());
         }

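The change above replaces the execBackwards-then-getArr() pattern with a single calculateGradients call whose result map is keyed by the name of the variable whose gradient was requested, not by the gradient variable's name. A minimal usage sketch, with illustrative variable names ("w", "b") on some SameDiff instance sd:

// Sketch only: not code from this change.
Map<String, INDArray> phValues = new HashMap<>();              // placeholder name -> value, filled as needed
List<String> wanted = Arrays.asList("w", "b");                 // variables we want gradients for
Map<String, INDArray> grads = sd.calculateGradients(phValues, wanted);
INDArray dLdW = grads.get("w");                                // keyed by the input variable name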
@@ -190,11 +190,13 @@ public class OpValidation {
         //Check forward pass:
         if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) {
             SameDiff sd = testCase.sameDiff();

+            //Collect variables we need outputs for...
+            Set<String> reqVars = testCase.fwdTestFns().keySet();
+
+            Map<String,INDArray> out;
             try {
-                if(testCase.placeholderValues() != null){
-                    sd.resolveVariablesWith(testCase.placeholderValues());
-                }
-                sd.exec(null, sd.outputs());
+                out = sd.output(testCase.placeholderValues(), new ArrayList<>(reqVars));
             } catch (Exception e) {
                 throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
             }

@@ -206,7 +208,7 @@ public class OpValidation {
                             e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
                 }

-                INDArray actual = v.getArr();
+                INDArray actual = out.get(v.getVarName());
                 if (actual == null) {
                     throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
                 }

@@ -291,6 +293,12 @@ public class OpValidation {
         Preconditions.checkState((orig.getControlDeps() == null) == (des.getControlDeps() == null), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());
         Preconditions.checkState(orig.getControlDeps() == null || orig.getControlDeps().equals(des.getControlDeps()), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());

+        Preconditions.checkState((orig.getVarControlDeps() == null) == (des.getVarControlDeps() == null), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
+        Preconditions.checkState(orig.getVarControlDeps() == null || orig.getVarControlDeps().equals(des.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
+
+        Preconditions.checkState((orig.getControlDepFor() == null) == (des.getControlDepFor() == null), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
+        Preconditions.checkState(orig.getControlDepFor() == null || orig.getControlDepFor().equals(des.getControlDepFor()), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
+
         Preconditions.checkState(orig.getOp().getClass() == des.getOp().getClass(), "Classes differ: %s v. %s", orig.getOp().getClass(), des.getOp().getClass());
     }

@@ -317,6 +325,11 @@ public class OpValidation {
         Map<String,Variable> varsBefore = original.getVariables();
         Map<String,Variable> varsAfter = deserialized.getVariables();
         Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet());

+        // System.out.println(original.summary());
+        // System.out.println("\n\n\n\n");
+        // System.out.println(deserialized.summary());
+
         for(String s : varsBefore.keySet()){
             Variable vB = varsBefore.get(s);
             Variable vA = varsAfter.get(s);

@@ -324,13 +337,15 @@ public class OpValidation {
             Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(),
                     "Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType());

-            Preconditions.checkState((vB.getInputsForOp() == null) == (vA.getInputsForOp() == null), "Input to ops differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp());
-            Preconditions.checkState(vB.getInputsForOp() == null || vB.getInputsForOp().equals(vA.getInputsForOp()), "Inputs differ: %s vs. %s", vB.getInputsForOp(), vA.getInputsForOp());
+            equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp());

-            Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "Output of op differ: %s vs. %s", vB.getOutputOfOp(), vA.getOutputOfOp());
+            Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp());

-            Preconditions.checkState((vB.getControlDeps() == null) == (vA.getControlDeps() == null), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps());
-            Preconditions.checkState(vB.getControlDeps() == null || vB.getControlDeps().equals(vA.getControlDeps()), "Control dependencies differ: %s vs. %s", vB.getControlDeps(), vA.getControlDeps());
+            equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps());
+
+            equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp());
+
+            equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar());
         }

         //Check loss variables:

@@ -343,7 +358,7 @@ public class OpValidation {
                     lossVarBefore, lossVarAfter);
         }


         if(tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) {
             //Finally: check execution/output
             Map<String,INDArray> outOrig = original.outputAll(tc.placeholderValues());
             Map<String,INDArray> outDe = deserialized.outputAll(tc.placeholderValues());

@@ -387,6 +402,17 @@ public class OpValidation {
                 Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err);
             }
         }
     }

+    protected static void equalConsideringNull(List<String> l1, List<String> l2, String msg, Object... args){
+        //Consider null and length 0 list to be equal (semantically they mean the same thing)
+        boolean empty1 = l1 == null || l1.isEmpty();
+        boolean empty2 = l2 == null || l2.isEmpty();
+        if(empty1 && empty2){
+            return;
+        }
+        Preconditions.checkState(l1 == null || l1.equals(l2), msg, args);
+    }
+
     /**
      * Validate the outputs of a single op

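equalConsideringNull above deliberately treats a null list and an empty list as the same thing, mirroring its in-line comment; after serialization and deserialization an absent list can plausibly come back in either form. The intended behaviour, illustrated (not code from this change; Preconditions.checkState throws IllegalStateException on failure):

// Sketch only.
equalConsideringNull(null, Collections.<String>emptyList(), "null vs empty should be accepted");        // passes
equalConsideringNull(Arrays.asList("x"), Arrays.asList("x"), "identical contents should be accepted");  // passes
equalConsideringNull(Arrays.asList("x"), Arrays.asList("y"), "differing lists: %s vs. %s", "x", "y");   // throws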
@@ -25,6 +25,7 @@ public class NonInplaceValidationListener extends BaseListener {
     private static AtomicInteger failCounter = new AtomicInteger();

     protected INDArray[] opInputs;
+    protected INDArray[] opInputsOrig;

     public NonInplaceValidationListener(){
         useCounter.getAndIncrement();

@@ -42,14 +43,18 @@ public class NonInplaceValidationListener extends BaseListener {
                 //No input op
                 return;
             } else if(o.y() == null){
+                opInputsOrig = new INDArray[]{o.x()};
                 opInputs = new INDArray[]{o.x().dup()};
             } else {
+                opInputsOrig = new INDArray[]{o.x(), o.y()};
                 opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
             }
         } else if(op.getOp() instanceof DynamicCustomOp){
             INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments();
             opInputs = new INDArray[arr.length];
+            opInputsOrig = new INDArray[arr.length];
             for( int i=0; i<arr.length; i++ ){
+                opInputsOrig[i] = arr[i];
                 opInputs[i] = arr[i].dup();
             }
         } else {

@@ -64,23 +69,6 @@ public class NonInplaceValidationListener extends BaseListener {
             return;
         }

-        INDArray[] inputsAfter;
-        if(op.getOp() instanceof Op){
-            Op o = (Op)op.getOp();
-            if(o.x() == null){
-                //No input op
-                return;
-            } else if(o.y() == null){
-                inputsAfter = new INDArray[]{o.x()};
-            } else {
-                inputsAfter = new INDArray[]{o.x(), o.y()};
-            }
-        } else if(op.getOp() instanceof DynamicCustomOp){
-            inputsAfter = ((DynamicCustomOp) op.getOp()).inputArguments();
-        } else {
-            throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
-        }
-
         MessageDigest md;
         try {
             md = MessageDigest.getInstance("MD5");

@@ -93,12 +81,12 @@ public class NonInplaceValidationListener extends BaseListener {

             //Need to hash - to ensure zero changes to input array
             byte[] before = opInputs[i].data().asBytes();
-            INDArray after = inputsAfter[i];
+            INDArray after = this.opInputsOrig[i];
             boolean dealloc = false;
-            if(opInputs[i].ordering() != inputsAfter[i].ordering() || Arrays.equals(opInputs[i].stride(), inputsAfter[i].stride())
-                    || opInputs[i].elementWiseStride() != inputsAfter[i].elementWiseStride()){
+            if(opInputs[i].ordering() != opInputsOrig[i].ordering() || Arrays.equals(opInputs[i].stride(), opInputsOrig[i].stride())
+                    || opInputs[i].elementWiseStride() != opInputsOrig[i].elementWiseStride()){
                 //Clone if required (otherwise fails for views etc)
-                after = inputsAfter[i].dup();
+                after = opInputsOrig[i].dup();
                 dealloc = true;
             }
             byte[] afterB = after.data().asBytes();

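The listener's check compares an MD5 digest of each input taken before execution (the dup'd copy) with a digest of the original array afterwards; matching digests indicate the op did not write to its inputs in place. A self-contained sketch of that idea (not the listener's actual code; "copyBefore" and "originalAfter" are assumed INDArrays, and the real code additionally dups views whose ordering or strides differ before hashing):

// Sketch only.
MessageDigest md;
try {
    md = MessageDigest.getInstance("MD5");
} catch (NoSuchAlgorithmException e) {
    throw new RuntimeException(e);
}
byte[] beforeHash = md.digest(copyBefore.data().asBytes());    // copy taken before op execution
md.reset();
byte[] afterHash = md.digest(originalAfter.data().asBytes());  // original array, after op execution
Preconditions.checkState(Arrays.equals(beforeHash, afterHash),
        "Op appears to have modified an input array in place");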
@@ -67,6 +67,12 @@ public final class FlatNode extends Table {
   public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
   public FlatArray scalar() { return scalar(new FlatArray()); }
   public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
+  public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
+  public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatNode(FlatBufferBuilder builder,
       int id,
@@ -87,9 +93,15 @@ public final class FlatNode extends Table {
       int outputNamesOffset,
       int opNameOffset,
       int outputTypesOffset,
-      int scalarOffset) {
-    builder.startObject(19);
+      int scalarOffset,
+      int controlDepsOffset,
+      int varControlDepsOffset,
+      int controlDepForOffset) {
+    builder.startObject(22);
     FlatNode.addOpNum(builder, opNum);
+    FlatNode.addControlDepFor(builder, controlDepForOffset);
+    FlatNode.addVarControlDeps(builder, varControlDepsOffset);
+    FlatNode.addControlDeps(builder, controlDepsOffset);
     FlatNode.addScalar(builder, scalarOffset);
     FlatNode.addOutputTypes(builder, outputTypesOffset);
     FlatNode.addOpName(builder, opNameOffset);

@@ -111,7 +123,7 @@ public final class FlatNode extends Table {
     return FlatNode.endFlatNode(builder);
   }

-  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); }
+  public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
   public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }

@@ -151,6 +163,15 @@ public final class FlatNode extends Table {
   public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
   public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
   public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
+  public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
+  public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
   public static int endFlatNode(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;

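The slot numbers used by the new add* helpers (19, 20, 21) and the offsets used by the new accessors (42, 44, 46) are two views of the same thing: in FlatBuffers-generated Java, field slot i is looked up at vtable offset 4 + 2*i. The same arithmetic explains the startObject bump from 19 to 22 fields:

// FlatBuffers vtable arithmetic (informational; matches the generated code above):
// controlDeps    -> slot 19 -> __offset(4 + 2*19) = __offset(42)
// varControlDeps -> slot 20 -> __offset(4 + 2*20) = __offset(44)
// controlDepFor  -> slot 21 -> __offset(4 + 2*21) = __offset(46)
// startObject(22): the table now declares 22 fields, slots 0..21.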
@@ -29,6 +29,12 @@ public final class FlatVariable extends Table {
   public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
   public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
   public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
+  public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
+  public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
+  public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatVariable(FlatBufferBuilder builder,
       int idOffset,
@@ -37,8 +43,14 @@ public final class FlatVariable extends Table {
       int shapeOffset,
       int ndarrayOffset,
       int device,
-      byte variabletype) {
-    builder.startObject(7);
+      byte variabletype,
+      int controlDepsOffset,
+      int controlDepForOpOffset,
+      int controlDepsForVarOffset) {
+    builder.startObject(10);
+    FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
+    FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
+    FlatVariable.addControlDeps(builder, controlDepsOffset);
     FlatVariable.addDevice(builder, device);
     FlatVariable.addNdarray(builder, ndarrayOffset);
     FlatVariable.addShape(builder, shapeOffset);

@@ -49,7 +61,7 @@ public final class FlatVariable extends Table {
     return FlatVariable.endFlatVariable(builder);
   }

-  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); }
+  public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
   public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
   public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
   public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }

@@ -59,6 +71,15 @@ public final class FlatVariable extends Table {
   public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
   public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
   public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
+  public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
+  public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
+  public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
+  public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
+  public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
   public static int endFlatVariable(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;

@@ -67,3 +88,4 @@ public final class FlatVariable extends Table {
   public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
 }
+

@@ -25,11 +25,7 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
 import org.nd4j.imports.descriptors.onnx.OpDescriptor;
 import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
 import org.nd4j.linalg.api.ops.*;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
-import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;

@@ -370,6 +366,8 @@ public class DifferentialFunctionClassHolder {
                 return Merge.class;
             case Switch.OP_NAME:
                 return Switch.class;
+            case LoopCond.OP_NAME:
+                return LoopCond.class;
             case ExternalErrorsFunction.OP_NAME:
                 return ExternalErrorsFunction.class;
             default:

@@ -69,13 +69,9 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan.class,
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual.class,
             org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.If.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative.class,
             org.nd4j.linalg.api.ops.impl.controlflow.Select.class,
             org.nd4j.linalg.api.ops.impl.controlflow.Where.class,
             org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.While.class,
-            org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit.class,
             org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.class,

@@ -1,413 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Base implementation for importing a graph
|
||||
*
|
||||
* @param <GRAPH_TYPE> the type of graph
|
||||
* @param <NODE_TYPE> the type of node
|
||||
* @param <ATTR_TYPE> the attribute type
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> implements GraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> {
|
||||
|
||||
|
||||
@Override
|
||||
public Op.Type opTypeForNode(NODE_TYPE nodeDef) {
|
||||
DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef));
|
||||
if (opWithTensorflowName == null)
|
||||
throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef));
|
||||
Op.Type type = opWithTensorflowName.opType();
|
||||
return type;
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings) {
|
||||
val mappings = propertyMappings.get(getOpType(node));
|
||||
if (mappings == null || mappings.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
for (val entry : mappings.entrySet()) {
|
||||
mapProperty(entry.getKey(), on, node, graph, sameDiff, propertyMappings);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @param inputStream
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public SameDiff importGraph(InputStream inputStream) {
|
||||
return importGraph(inputStream, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff importGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
|
||||
GRAPH_TYPE def = readGraph(inputStream, opImportOverrides);
|
||||
return importGraph(def, opImportOverrides, opFilter);
|
||||
}
|
||||
|
||||
protected GRAPH_TYPE readGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides) {
|
||||
byte[] bytes = null;
|
||||
GRAPH_TYPE def = null;
|
||||
try {
|
||||
bytes = IOUtils.toByteArray(inputStream); //Buffers internally
|
||||
def = parseGraphFrom(bytes);
|
||||
} catch (IOException e) {
|
||||
try (BufferedInputStream bis2 = new BufferedInputStream(new ByteArrayInputStream(bytes)); BufferedReader reader = new BufferedReader(new InputStreamReader(bis2))) {
|
||||
Message.Builder builder = getNewGraphBuilder();
|
||||
|
||||
StringBuilder str = new StringBuilder();
|
||||
String line = null;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
str.append(line);//.append("\n");
|
||||
}
|
||||
|
||||
TextFormat.getParser().merge(str.toString(), builder);
|
||||
def = (GRAPH_TYPE) builder.build();
|
||||
} catch (Exception e2) {
|
||||
e2.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param graphFile
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public SameDiff importGraph(File graphFile) {
|
||||
return importGraph(graphFile, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
|
||||
GRAPH_TYPE def = null;
|
||||
try (FileInputStream fis = new FileInputStream(graphFile)) {
|
||||
return importGraph(fis, opImportOverrides, opFilter);
|
||||
} catch (Exception e) {
|
||||
throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph) {
|
||||
List<NODE_TYPE> nodes = getNodeList(graph);
|
||||
Map<String, NODE_TYPE> ret = new HashMap<>();
|
||||
for (NODE_TYPE node : nodes) {
|
||||
ret.put(getName(node), node);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, NODE_TYPE> nodesByName(GRAPH_TYPE graph) {
|
||||
val nodeTypes = getNodeList(graph);
|
||||
Map<String, NODE_TYPE> ret = new LinkedHashMap<>();
|
||||
for (int i = 0; i < nodeTypes.size(); i++) {
|
||||
ret.put(getName(nodeTypes.get(i)), nodeTypes.get(i));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method converts given TF
|
||||
*
|
||||
* @param tfGraph
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
|
||||
return importGraph(tfGraph, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff importGraph(GRAPH_TYPE tfGraph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
|
||||
|
||||
SameDiff diff = SameDiff.create();
|
||||
ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
|
||||
importState.setSameDiff(diff);
|
||||
importState.setGraph(tfGraph);
|
||||
|
||||
Map<String,TENSOR_TYPE> variablesForGraph = variablesForGraph(tfGraph);
|
||||
importState.setVariables(variablesForGraph);
|
||||
|
||||
|
||||
//Add each of the variables first - before importing ops
|
||||
Map<String, Boolean> stringNodes = new HashMap<>(); //Key: name of string variable. Value: if it's a constant
|
||||
for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
|
||||
if (shouldSkip((NODE_TYPE) entry.getValue())) { //TODO only works for TF
|
||||
//Skip some nodes, for example reduction indices (a lot of ND4J/SameDiff ops use int[] etc, not an INDArray/Variable)
|
||||
continue;
|
||||
}
|
||||
|
||||
//First: check if we're skipping the op entirely. If so: don't create the output variables for it.
|
||||
NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
|
||||
String opType = getOpType(node);
|
||||
String opName = getName(node);
|
||||
if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, importState.getGraph() )){
|
||||
log.info("Skipping variables for op: {} (name: {})", opType, opName);
|
||||
continue;
|
||||
}
|
||||
|
||||
//Similarly, if an OpImportOverride is defined, don't create the variables now, as these might be the wrong type
|
||||
//For example, the OpImportOverride might replace the op with some placeholders
|
||||
// If we simply created them now, we'd create the wrong type (Array not placeholder)
|
||||
if(opImportOverrides != null && opImportOverrides.containsKey(opType)){
|
||||
log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, opName);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
DataType dt = dataTypeForTensor(entry.getValue(), 0);
|
||||
INDArray arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
|
||||
long[] shape = hasShape((NODE_TYPE) entry.getValue()) ? getShape((NODE_TYPE) entry.getValue()) : null; //TODO only works for TF
|
||||
|
||||
//Not all variables have datatypes available on import - we have to infer these at a later point
|
||||
// so we'll leave datatypes as null and infer them once all variables/ops have been imported
|
||||
if(dt == DataType.UNKNOWN)
|
||||
dt = null;
|
||||
|
||||
if (isPlaceHolder(entry.getValue())) {
|
||||
diff.placeHolder(entry.getKey(), dt, shape);
|
||||
} else if (isConstant(entry.getValue())) {
|
||||
Preconditions.checkNotNull(arr, "Array is null for placeholder variable %s", entry.getKey());
|
||||
diff.constant(entry.getKey(), arr);
|
||||
} else {
|
||||
//Could be variable, or could be array type (i.e., output of op/"activations")
|
||||
//TODO work out which!
|
||||
|
||||
SDVariable v;
|
||||
if(shape == null || ArrayUtil.contains(shape, 0)){
|
||||
//No shape, or 0 in shape -> probably not a variable...
|
||||
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
|
||||
} else {
|
||||
v = diff.var(entry.getKey(), dt, shape);
|
||||
}
|
||||
if (arr != null)
|
||||
diff.associateArrayWithVariable(arr, v);
|
||||
}
|
||||
|
||||
// NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
|
||||
List<String> controlDependencies = getControlDependencies(node);
|
||||
if (controlDependencies != null) {
|
||||
Variable v = diff.getVariables().get(entry.getKey());
|
||||
v.setControlDeps(controlDependencies);
|
||||
}
|
||||
}
|
||||
|
||||
//Map ops
|
||||
val tfNodesList = getNodeList(tfGraph);
|
||||
for (NODE_TYPE node : tfNodesList) {
|
||||
String opType = getOpType(node);
|
||||
OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> importOverride = null;
|
||||
if(opImportOverrides != null){
|
||||
importOverride = opImportOverrides.get(opType);
|
||||
}
|
||||
|
||||
if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, null)){
|
||||
String opName = getName(node);
|
||||
log.info("Skipping op due to op filter: {}", opType, opName);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) {
|
||||
mapNodeType(node, importState, importOverride, opFilter);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
At this point, we have a few remaining things to do:
|
||||
1. Make sure all datatypes are set on all variables. TF doesn't have datatype info on all op outputs for some reason, so we have to infer it manually
|
||||
2. Make sure all op output variables have been created
|
||||
3. Make sure all SameDiffOp.outputsOfOp is set
|
||||
4. Make sure all Variable.outputOfOp is set
|
||||
5. Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
|
||||
*/
|
||||
|
||||
//Make sure Variable.outputOfOp is set
|
||||
for(Variable v : diff.getVariables().values()){
|
||||
if(v.getVariable().isPlaceHolder() || v.getVariable().isConstant())
|
||||
continue;
|
||||
|
||||
//Expect variable names of output variables to be: opName, opName:1, opName:2, etc
|
||||
String n = v.getName();
|
||||
String opName = n;
|
||||
if(v.getName().matches(".*:\\d+")){
|
||||
//i.e., "something:2"
|
||||
int idx = n.lastIndexOf(':');
|
||||
opName = n.substring(0,idx);
|
||||
}
|
||||
|
||||
if(diff.getOps().containsKey(opName)) {
|
||||
//Variable is the output of an op
|
||||
v.setOutputOfOp(opName);
|
||||
|
||||
//Also double check variable type...
|
||||
if(v.getVariable().getVariableType() != VariableType.ARRAY)
|
||||
v.getVariable().setVariableType(VariableType.ARRAY);
|
||||
}
|
||||
}
|
||||
|
||||
//Initialize any missing output variables
|
||||
for (SameDiffOp op : diff.getOps().values()) {
|
||||
DifferentialFunction df = op.getOp();
|
||||
initOutputVariables(diff, df);
|
||||
}
|
||||
|
||||
//Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
|
||||
//i.e., if control dependency x -> y exists, then:
|
||||
// (a) x.controlDepsForVar should contain "y"
|
||||
// (b) y.controlDeps should contain "x"
|
||||
//Need to do this before output datatype calculation, as these control dep info is used in sessions
|
||||
for(Map.Entry<String,Variable> e : diff.getVariables().entrySet()){
|
||||
Variable v = e.getValue();
|
||||
if(v.getControlDeps() != null){
|
||||
for(String s : v.getControlDeps()){
|
||||
Variable v2 = diff.getVariables().get(s);
|
||||
if(v2.getControlDepsForVar() == null)
|
||||
v2.setControlDepsForVar(new ArrayList<String>());
|
||||
if(!v2.getControlDepsForVar().contains(e.getKey())){
|
||||
//Control dep v2 -> v exists, so put v.name into v2.controlDepsForVar
|
||||
v2.getControlDepsForVar().add(e.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Same thing for op control dependencies...
|
||||
for(Map.Entry<String,SameDiffOp> e : diff.getOps().entrySet()){
|
||||
SameDiffOp op = e.getValue();
|
||||
if(op.getControlDeps() != null){
|
||||
for(String s : op.getControlDeps()){
|
||||
//Control dependency varS -> op exists
|
||||
Variable v = diff.getVariables().get(s);
|
||||
if(v.getControlDepsForOp() == null)
|
||||
v.setControlDepsForOp(new ArrayList<String>());
|
||||
if(!v.getControlDepsForOp().contains(e.getKey()))
|
||||
v.getControlDepsForOp().add(e.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//Infer variable datatypes to ensure all variables have datatypes...
|
||||
boolean anyUnknown = false;
|
||||
for(SDVariable v : diff.variables()){
|
||||
if(v.dataType() == null)
|
||||
anyUnknown = true;
|
||||
}
|
||||
if(anyUnknown){
|
||||
Map<String,DataType> dataTypes = diff.calculateOutputDataTypes();
|
||||
for(SDVariable v : diff.variables()){
|
||||
if(v.dataType() == null){
|
||||
v.setDataType(dataTypes.get(v.getVarName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Validate the graph structure
|
||||
validateGraphStructure(diff);
|
||||
|
||||
return diff;
|
||||
}
|
||||
|
||||
protected void initOutputVariables(SameDiff sd, DifferentialFunction df) {
|
||||
String[] outNames = sd.getOutputsForOp(df);
|
||||
SDVariable[] outVars;
|
||||
if (outNames == null) {
|
||||
outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true);
|
||||
outNames = new String[outVars.length];
|
||||
for (int i = 0; i < outVars.length; i++) {
|
||||
outNames[i] = outVars[i].getVarName();
|
||||
}
|
||||
sd.getOps().get(df.getOwnName()).setOutputsOfOp(Arrays.asList(outNames));
|
||||
}
|
||||
|
||||
for (String s : outNames) {
|
||||
sd.getVariables().get(s).setOutputOfOp(df.getOwnName());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean validTensorDataType(TENSOR_TYPE tensorType) {
|
||||
return dataTypeForTensor(tensorType, 0) != DataType.UNKNOWN;
|
||||
}
|
||||
|
||||
public void validateGraphStructure(SameDiff sameDiff) {
|
||||
//First: Check placeholders. When SDVariables are added with null shapes, these can be interpreted as a placeholder
|
||||
// but null shapes might simply mean shape isn't available during import right when the variable is added
|
||||
//Idea here: if a "placeholder" is the output of any function, it's not really a placeholder
|
||||
for (SDVariable v : sameDiff.variables()) {
|
||||
String name = v.getVarName();
|
||||
if (sameDiff.isPlaceHolder(name)) {
|
||||
String varOutputOf = sameDiff.getVariables().get(name).getOutputOfOp();
|
||||
}
|
||||
}
|
||||
|
||||
//Second: check that all op inputs actually exist in the graph
|
||||
for (SameDiffOp op : sameDiff.getOps().values()) {
|
||||
List<String> inputs = op.getInputsToOp();
|
||||
if (inputs == null)
|
||||
continue;
|
||||
|
||||
for (String s : inputs) {
|
||||
if (sameDiff.getVariable(s) == null) {
|
||||
throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName()
|
||||
+ " has input \"" + s + "\" that does not have a corresponding variable in the graph");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@@ -1,429 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Map graph proto types to
|
||||
*
|
||||
* {@link SameDiff} instances
|
||||
* @param <GRAPH_TYPE> the proto type for the graph
|
||||
* @param <NODE_TYPE> the proto type for the node
|
||||
* @param <ATTR_TYPE> the proto type for the attribute
|
||||
* @param <TENSOR_TYPE> the proto type for the tensor
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE> {
|
||||
|
||||
/**
|
||||
* Import a graph as SameDiff from the given file
|
||||
* @param graphFile Input stream pointing to graph file to import
|
||||
* @return Imported graph
|
||||
*/
|
||||
SameDiff importGraph(InputStream graphFile);
|
||||
|
||||
SameDiff importGraph(InputStream graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
|
||||
|
||||
/**
|
||||
* Import a graph as SameDiff from the given file
|
||||
* @param graphFile Graph file to import
|
||||
* @return Imported graph
|
||||
* @see #importGraph(File, Map)
|
||||
*/
|
||||
SameDiff importGraph(File graphFile);
|
||||
|
||||
/**
|
||||
* Import a graph as SameDiff from the given file, with optional op import overrides.<br>
|
||||
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
|
||||
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
|
||||
*
|
||||
* @param graphFile Graph file to import
|
||||
* @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the
|
||||
* operation to import, value is the object used to import it
|
||||
* @return Imported graph
|
||||
*/
|
||||
SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
|
||||
|
||||
/**
|
||||
* This method converts given graph type (in its native format) to SameDiff
|
||||
* @param graph Graph to import
|
||||
* @return Imported graph
|
||||
*/
|
||||
SameDiff importGraph(GRAPH_TYPE graph);
|
||||
|
||||
/**
|
||||
* This method converts given graph type (in its native format) to SameDiff<br>
|
||||
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
|
||||
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
|
||||
* @param graph Graph to import
|
||||
* @return Imported graph
|
||||
*/
|
||||
SameDiff importGraph(GRAPH_TYPE graph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
|
||||
|
||||
|
||||
/**
|
||||
* Returns true if this node is a special case
|
||||
* (maybe because of name or other scenarios)
|
||||
* that should override {@link #opsToIgnore()}
|
||||
* in certain circumstances
|
||||
* @param node the node to check
|
||||
* @return true if this node is an exception false otherwise
|
||||
*/
|
||||
boolean isOpIgnoreException(NODE_TYPE node);
|
||||
|
||||
/**
|
||||
* Get the nodes sorted by name
|
||||
* from a given graph
|
||||
* @param graph the graph to get the nodes for
|
||||
* @return the map of the nodes by name
|
||||
* for a given graph
|
||||
*/
|
||||
Map<String,NODE_TYPE> nodesByName(GRAPH_TYPE graph);
|
||||
|
||||
/**
|
||||
* Get the target mapping key (usually based on the node name)
|
||||
* for the given function
|
||||
* @param function the function
|
||||
* @param node the node to derive the target mapping from
|
||||
* @return
|
||||
*/
|
||||
String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param on
|
||||
* @param node
|
||||
* @param graph
|
||||
* @param sameDiff
|
||||
* @param propertyMappings
|
||||
*/
|
||||
void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name
|
||||
* @param on
|
||||
* @param node
|
||||
* @param graph
|
||||
* @param sameDiff
|
||||
* @param propertyMappingsForFunction
|
||||
*/
|
||||
void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction);
|
||||
|
||||
/**
|
||||
* Get the node from the graph
|
||||
* @param graph the graph to get the node from
|
||||
* @param name the name of the node to get from the graph
|
||||
* @return
|
||||
*/
|
||||
NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name);
|
||||
|
||||
/**
|
||||
* Returns true if the given node is a place holder
|
||||
* @param node the node to check
|
||||
* @return true if the node is a place holder or not
|
||||
*/
|
||||
boolean isPlaceHolderNode(TENSOR_TYPE node);
|
||||
|
||||
/**
|
||||
* Get the list of control dependencies for the current node (or null if none exist)
|
||||
*
|
||||
* @param node Node to get the control dependencies (if any) for
|
||||
* @return
|
||||
*/
|
||||
List<String> getControlDependencies(NODE_TYPE node);
|
||||
|
||||
/**
|
||||
* Dump a binary proto file representation as a
|
||||
* plain string in to the target text file
|
||||
* @param inputFile
|
||||
* @param outputFile
|
||||
*/
|
||||
void dumpBinaryProtoAsText(File inputFile,File outputFile);
|
||||
|
||||
|
||||
/**
|
||||
* Dump a binary proto file representation as a
|
||||
* plain string in to the target text file
|
||||
* @param inputFile
|
||||
* @param outputFile
|
||||
*/
|
||||
void dumpBinaryProtoAsText(InputStream inputFile,File outputFile);
|
||||
|
||||
|
||||
/**
|
||||
* Get the mapped op name
|
||||
* for a given op
|
||||
* relative to the type of node being mapped.
|
||||
* The input name should be based on a tensorflow
|
||||
* type or onnx type, not the nd4j name
|
||||
* @param name the tensorflow or onnx name
|
||||
* @return the function based on the values in
|
||||
* {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder}
|
||||
*/
|
||||
DifferentialFunction getMappedOp(String name);
|
||||
|
||||
|
||||
/**
|
||||
* Get the variables for the given graph
|
||||
* @param graphType the graph to get the variables for
|
||||
* @return a map of variable name to tensor
|
||||
*/
|
||||
Map<String,TENSOR_TYPE> variablesForGraph(GRAPH_TYPE graphType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name
|
||||
* @param node
|
||||
* @return
|
||||
*/
|
||||
String translateToSameDiffName(String name, NODE_TYPE node);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param graph
|
||||
* @return
|
||||
*/
|
||||
Map<String,NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph);
|
||||
|
||||
/**
|
||||
* Returns an op type for the given input node
|
||||
* @param nodeType the node to use
|
||||
* @return the optype for the given node
|
||||
*/
|
||||
Op.Type opTypeForNode(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
* Returns a graph builder for initial definition and parsing.
|
||||
* @return
|
||||
*/
|
||||
Message.Builder getNewGraphBuilder();
|
||||
|
||||
/**
|
||||
* Parse a graph from an input stream
|
||||
* @param inputStream the input stream to load from
|
||||
* @return
|
||||
*/
|
||||
GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException;
|
||||
|
||||
/**
|
||||
* Parse a graph from an input stream
|
||||
* @param inputStream the input stream to load from
|
||||
* @return
|
||||
*/
|
||||
GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException;
|
||||
|
||||
|
||||
/**
|
||||
* Map a node in to the import state covering the {@link SameDiff} instance
|
||||
* @param tfNode the node to map
|
||||
* @param importState the current import state
|
||||
* @param opFilter Optional filter for skipping operations
|
||||
*/
|
||||
void mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState,
|
||||
OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opImportOverride,
|
||||
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param tensorType
|
||||
* @param outputNum
|
||||
* @return
|
||||
*/
|
||||
DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum);
|
||||
|
||||
boolean isStringType(TENSOR_TYPE tensor);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
String getAttrValueFromNode(NODE_TYPE nodeType,String key);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param attrType
|
||||
* @return
|
||||
*/
|
||||
long[] getShapeFromAttribute(ATTR_TYPE attrType);
|
||||
|
||||
/**
|
||||
* Returns true if the given node is a place holder type
|
||||
* (think a yet to be determined shape)
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
boolean isPlaceHolder(TENSOR_TYPE nodeType);
|
||||
|
||||
|
||||
/**
|
||||
* Returns true if the given node is a constant
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
boolean isConstant(TENSOR_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* @param tensorName
|
||||
* @param tensorType
|
||||
* @param graph
|
||||
* @return
|
||||
*/
|
||||
INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph);
|
||||
|
||||
|
||||
/**
|
||||
* Get the shape for the given tensor type
|
||||
* @param tensorType
|
||||
* @return
|
||||
*/
|
||||
long[] getShapeFromTensor(TENSOR_TYPE tensorType);
|
||||
|
||||
|
||||
/**
|
||||
* Ops to ignore for mapping
|
||||
* @return
|
||||
*/
|
||||
Set<String> opsToIgnore();
|
||||
|
||||
/**
|
||||
* Get the input node for the given node
|
||||
* @param node the node
|
||||
* @param index the index
|
||||
* @return
|
||||
*/
|
||||
String getInputFromNode(NODE_TYPE node, int index);
|
||||
|
||||
/**
|
||||
* Get the number of inputs for a node.
|
||||
* @param nodeType the node to get the number of inputs for
|
||||
* @return
|
||||
*/
|
||||
int numInputsFor(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
* Whether the data type for the tensor is valid
|
||||
* for creating an {@link INDArray}
|
||||
* @param tensorType the tensor proto to test
|
||||
* @return
|
||||
*/
|
||||
boolean validTensorDataType(TENSOR_TYPE tensorType);
|
||||
|
||||
|
||||
/**
|
||||
* Get the shape of the attribute value
|
||||
* @param attr the attribute value
|
||||
* @return the shape of the attribute if any or null
|
||||
*/
|
||||
long[] getShapeFromAttr(ATTR_TYPE attr);
|
||||
|
||||
/**
|
||||
* Get the attribute
|
||||
* map for given node
|
||||
* @param nodeType the node
|
||||
* @return the attribute map for the attribute
|
||||
*/
|
||||
Map<String,ATTR_TYPE> getAttrMap(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
* Get the name of the node
|
||||
* @param nodeType the node
|
||||
* to get the name for
|
||||
* @return
|
||||
*/
|
||||
String getName(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
boolean alreadySeen(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
boolean isVariableNode(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* @param opType
|
||||
* @return
|
||||
*/
|
||||
boolean shouldSkip(NODE_TYPE opType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
boolean hasShape(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @return
|
||||
*/
|
||||
long[] getShape(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param nodeType
|
||||
* @param graph
|
||||
* @return
|
||||
*/
|
||||
INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph);
|
||||
|
||||
|
||||
String getOpType(NODE_TYPE nodeType);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param graphType
|
||||
* @return
|
||||
*/
|
||||
List<NODE_TYPE> getNodeList(GRAPH_TYPE graphType);
|
||||
}
|
|
@@ -1,31 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class ImportState<GRAPH_TYPE,TENSOR_TYPE> {
|
||||
private SameDiff sameDiff;
|
||||
private GRAPH_TYPE graph;
|
||||
private Map<String,TENSOR_TYPE> variables;
|
||||
|
||||
|
||||
}
|
|
@@ -1,652 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.imports.graphmapper.onnx;
|
||||
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.guava.primitives.Floats;
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import org.nd4j.shade.guava.primitives.Longs;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.BaseGraphMapper;
|
||||
import org.nd4j.imports.graphmapper.ImportState;
|
||||
import org.nd4j.imports.graphmapper.OpImportFilter;
|
||||
import org.nd4j.imports.graphmapper.OpImportOverride;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* A mapper for onnx graphs to
|
||||
* {@link org.nd4j.autodiff.samediff.SameDiff} instances.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
|
||||
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
|
||||
|
||||
|
||||
public static OnnxGraphMapper getInstance() {
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
|
||||
try {
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString() + "\n");
|
||||
}
|
||||
|
||||
bufferedWriter.flush();
|
||||
bufferedWriter.close();
|
||||
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Init a function's attributes
|
||||
* @param mappedTfName the onnx name to pick (sometimes ops have multiple names
|
||||
* @param on the function to map
|
||||
* @param attributesForNode the attributes for the node
|
||||
* @param node
|
||||
* @param graph
|
||||
*/
|
||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
|
||||
val properties = on.mappingsForFunction();
|
||||
val tfProperties = properties.get(mappedTfName);
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
val attributeAdapters = on.attributeAdaptersForFunction();
|
||||
for(val entry : tfProperties.entrySet()) {
|
||||
val tfAttrName = entry.getValue().getTfAttrName();
|
||||
val currentField = fields.get(entry.getKey());
|
||||
|
||||
AttributeAdapter adapter = null;
|
||||
if(tfAttrName != null) {
|
||||
if(currentField == null) {
|
||||
continue;
|
||||
}
|
||||
if(attributeAdapters != null && !attributeAdapters.isEmpty()) {
|
||||
val mappers = attributeAdapters.get(on.tensorflowName());
|
||||
val adapterFor = mappers.get(entry.getKey());
|
||||
adapter = adapterFor;
|
||||
}
|
||||
|
||||
|
||||
if(attributesForNode.containsKey(tfAttrName)) {
|
||||
val attr = attributesForNode.get(tfAttrName);
|
||||
switch (attr.getType()) {
|
||||
case STRING:
|
||||
val setString = attr.getS().toStringUtf8();
|
||||
if(adapter != null) {
|
||||
adapter.mapAttributeFor(setString,currentField,on);
|
||||
}
|
||||
else
|
||||
on.setValueFor(currentField,setString);
|
||||
break;
|
||||
case INT:
|
||||
val setInt = (int) attr.getI();
|
||||
if(adapter != null) {
|
||||
adapter.mapAttributeFor(setInt,currentField,on);
|
||||
}
|
||||
else
|
||||
on.setValueFor(currentField,setInt);
|
||||
break;
|
||||
case INTS:
|
||||
val setList = attr.getIntsList();
|
||||
if(!setList.isEmpty()) {
|
||||
val intList = Ints.toArray(setList);
|
||||
if(adapter != null) {
|
||||
adapter.mapAttributeFor(intList,currentField,on);
|
||||
}
|
||||
else
|
||||
on.setValueFor(currentField,intList);
|
||||
}
|
||||
break;
|
||||
case FLOATS:
|
||||
val floatsList = attr.getFloatsList();
|
||||
if(!floatsList.isEmpty()) {
|
||||
val floats = Floats.toArray(floatsList);
|
||||
if(adapter != null) {
|
||||
adapter.mapAttributeFor(floats,currentField,on);
|
||||
}
|
||||
|
||||
else
|
||||
on.setValueFor(currentField,floats);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case TENSOR:
|
||||
val tensorToGet = mapTensorProto(attr.getT());
|
||||
if(adapter != null) {
|
||||
adapter.mapAttributeFor(tensorToGet,currentField,on);
|
||||
}
|
||||
else
|
||||
on.setValueFor(currentField,tensorToGet);
|
||||
break;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isOpIgnoreException(Onnx.NodeProto node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
|
||||
return function.opName();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
/**
|
||||
* Map ints and the like. Need to figure out how attribute mapping should work.
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
val propsForFunction = on.propertiesForFunction();
|
||||
|
||||
if(mapping.getTfAttrName() == null) {
|
||||
int tfMappingIdx = mapping.getTfInputPosition();
|
||||
if(tfMappingIdx < 0)
|
||||
tfMappingIdx += node.getInputCount();
|
||||
|
||||
val input = node.getInput(tfMappingIdx);
|
||||
val inputNode = getInstance().getNodeWithNameFromGraph(graph,input);
|
||||
INDArray arr = sameDiff.getArrForVarName(input);
|
||||
val field = fields.get(mapping.getPropertyNames()[0]);
|
||||
val type = field.getType();
|
||||
if(type.equals(int[].class)) {
|
||||
try {
|
||||
field.set(arr.data().asInt(),on);
|
||||
} catch (IllegalAccessException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
|
||||
try {
|
||||
field.set(arr.getInt(0),on);
|
||||
} catch (IllegalAccessException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
|
||||
try {
|
||||
field.set(arr.getDouble(0),on);
|
||||
} catch (IllegalAccessException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Figure out whether it's an int array
|
||||
* or a double array, or maybe a scalar.
|
||||
*/
|
||||
|
||||
}
|
||||
else {
|
||||
val tfMappingAttrName = mapping.getOnnxAttrName();
|
||||
val attr = getAttrMap(node).get(tfMappingAttrName);
|
||||
val type = attr.getType();
|
||||
val field = fields.get(mapping.getPropertyNames()[0]);
|
||||
|
||||
Object valueToSet = null;
|
||||
switch(type) {
|
||||
case INT:
|
||||
valueToSet = attr.getI();
|
||||
break;
|
||||
case FLOAT:
|
||||
valueToSet = attr.getF();
|
||||
break;
|
||||
case STRING:
|
||||
valueToSet = attr.getF();
|
||||
break;
|
||||
|
||||
}
|
||||
|
||||
try {
|
||||
field.set(valueToSet,on);
|
||||
} catch (IllegalAccessException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
|
||||
for(int i = 0; i < graph.getNodeCount(); i++) {
|
||||
val node = graph.getNode(i);
|
||||
if(node.getName().equals(name))
|
||||
return node;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getControlDependencies(Onnx.NodeProto node) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
|
||||
try {
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString());
|
||||
}
|
||||
|
||||
bufferedWriter.flush();
|
||||
bufferedWriter.close();
|
||||
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name the tensorflow or onnx name
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public DifferentialFunction getMappedOp(String name) {
|
||||
return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
|
||||
/**
|
||||
* Need to figure out why
|
||||
* gpu_0/conv1_1 isn't present in VGG
|
||||
*/
|
||||
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
|
||||
for(int i = 0; i < graphProto.getInputCount(); i++) {
|
||||
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
|
||||
}
|
||||
|
||||
for(int i = 0; i < graphProto.getOutputCount(); i++) {
|
||||
ret.put(graphProto.getOutput(i).getName(),graphProto.getOutput(i).getType().getTensorType());
|
||||
}
|
||||
|
||||
for(int i = 0; i < graphProto.getNodeCount(); i++) {
|
||||
val node = graphProto.getNode(i);
|
||||
val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
|
||||
//add -1 as place holder value representing the shape needs to be filled in
|
||||
if(!ret.containsKey(name)) {
|
||||
addDummyTensor(name,ret);
|
||||
}
|
||||
|
||||
for(int j = 0; j < node.getInputCount(); j++) {
|
||||
if(!ret.containsKey(node.getInput(j))) {
|
||||
addDummyTensor(node.getInput(j),ret);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for(int j = 0; j < node.getOutputCount(); j++) {
|
||||
if(!ret.containsKey(node.getOutput(j))) {
|
||||
addDummyTensor(node.getOutput(j),ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
|
||||
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
|
||||
newBuilder()
|
||||
.setDimValue(-1)
|
||||
.build();
|
||||
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
|
||||
.setShape(
|
||||
Onnx.TensorShapeProto.newBuilder()
|
||||
.addDim(dim)
|
||||
.addDim(dim).build())
|
||||
.build();
|
||||
to.put(name,typeProto);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message.Builder getNewGraphBuilder() {
|
||||
return Onnx.GraphProto.newBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
|
||||
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
|
||||
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
|
||||
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
|
||||
if(differentialFunction == null) {
|
||||
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
|
||||
}
|
||||
|
||||
val diff = importState.getSameDiff();
|
||||
val idx = importState.getGraph().getNodeList().indexOf(tfNode);
|
||||
val name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx);
|
||||
try {
|
||||
val newInstance = differentialFunction.getClass().newInstance();
|
||||
val args = new SDVariable[tfNode.getInputCount()];
|
||||
|
||||
newInstance.setSameDiff(importState.getSameDiff());
|
||||
|
||||
newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
|
||||
importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance);
|
||||
//ensure we can track node name to function instance later.
|
||||
diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
|
||||
//diff.addVarNameForImport(tfNode.getName());
|
||||
}
|
||||
catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
|
||||
return nd4jTypeFromOnnxType(tensorProto.getElemType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
|
||||
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert an onnx type to the proper nd4j type
|
||||
* @param dataType the data type to convert
|
||||
* @return the nd4j type for the onnx type
|
||||
*/
|
||||
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
|
||||
switch (dataType) {
|
||||
case DOUBLE: return DataType.DOUBLE;
|
||||
case FLOAT: return DataType.FLOAT;
|
||||
case FLOAT16: return DataType.HALF;
|
||||
case INT32:
|
||||
case INT64: return DataType.INT;
|
||||
default: return DataType.UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
|
||||
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||
if(attributeProto.getName().equals(key)) {
|
||||
return attributeProto.getS().toString();
|
||||
}
|
||||
}
|
||||
|
||||
throw new ND4JIllegalStateException("No key found for " + key);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
|
||||
return Longs.toArray(attributeProto.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
|
||||
DataType type = dataTypeForTensor(tensorProto, 0);
|
||||
if(!tensorProto.isInitialized()) {
|
||||
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
|
||||
}
|
||||
|
||||
Onnx.TensorProto tensor = null;
|
||||
for(int i = 0; i < graph.getInitializerCount(); i++) {
|
||||
val initializer = graph.getInitializer(i);
|
||||
if(initializer.getName().equals(tensorName)) {
|
||||
tensor = initializer;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(tensor == null)
|
||||
return null;
|
||||
|
||||
ByteString bytes = tensor.getRawData();
|
||||
ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
|
||||
ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
|
||||
directAlloc.put(byteBuffer);
|
||||
directAlloc.rewind();
|
||||
long[] shape = getShapeFromTensor(tensorProto);
|
||||
DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
|
||||
INDArray arr = Nd4j.create(buffer).reshape(shape);
|
||||
return arr;
|
||||
}
|
||||
|
||||
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
|
||||
if(tensor == null)
|
||||
return null;
|
||||
|
||||
|
||||
DataType type = nd4jTypeFromOnnxType(tensor.getDataType());
|
||||
|
||||
ByteString bytes = tensor.getRawData();
|
||||
ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
|
||||
ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
|
||||
directAlloc.put(byteBuffer);
|
||||
directAlloc.rewind();
|
||||
long[] shape = getShapeFromTensor(tensor);
|
||||
DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
|
||||
INDArray arr = Nd4j.create(buffer).reshape(shape);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
|
||||
int dimCount = tensorProto.getShape().getDimCount();
|
||||
if(dimCount >= 2)
|
||||
for(int i = 0; i < ret.length; i++) {
|
||||
ret[i] = (int) tensorProto.getShape().getDim(i).getDimValue();
|
||||
}
|
||||
else {
|
||||
ret[0] = 1;
|
||||
for(int i = 1; i < ret.length; i++) {
|
||||
ret[i] = (int) tensorProto.getShape().getDim(i - 1).getDimValue();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get the shape from a tensor proto.
|
||||
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
|
||||
* @param tensorProto the tensor to get the shape from
|
||||
* @return
|
||||
*/
|
||||
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
|
||||
int dimCount = tensorProto.getDimsCount();
|
||||
if(dimCount >= 2)
|
||||
for(int i = 0; i < ret.length; i++) {
|
||||
ret[i] = (int) tensorProto.getDims(i);
|
||||
}
|
||||
else {
|
||||
ret[0] = 1;
|
||||
for(int i = 1; i < ret.length; i++) {
|
||||
ret[i] = (int) tensorProto.getDims(i - 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> opsToIgnore() {
|
||||
return Collections.emptySet();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String getInputFromNode(Onnx.NodeProto node, int index) {
|
||||
return node.getInput(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numInputsFor(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getInputCount();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
|
||||
return Longs.toArray(attr.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
|
||||
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
|
||||
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
|
||||
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||
proto.put(attributeProto.getName(),attributeProto);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType().contains("Var");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldSkip(Onnx.NodeProto opType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasShape(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getShape(Onnx.NodeProto nodeProto) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOpType(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
|
||||
return graphProto.getNodeList();
|
||||
}
|
||||
|
||||
|
||||
}
|
(File diff suppressed because it is too large)
@@ -226,22 +226,24 @@ public class TensorFlowImportValidator {
    }

    public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException {
        TFGraphMapper m = TFGraphMapper.getInstance();

        try {
            int opCount = 0;
            Set<String> opNames = new HashSet<>();

            try(InputStream bis = new BufferedInputStream(is)) {
                GraphDef graphDef = m.parseGraphFrom(bis);
                List<NodeDef> nodes = m.getNodeList(graphDef);
                GraphDef graphDef = GraphDef.parseFrom(bis);
                List<NodeDef> nodes = new ArrayList<>(graphDef.getNodeCount());
                for( int i=0; i<graphDef.getNodeCount(); i++ ){
                    nodes.add(graphDef.getNode(i));
                }

                if(nodes.isEmpty()){
                    throw new IllegalStateException("Error loading model for import - loaded graph def has no nodes (empty/corrupt file?): " + path);
                }

                for (NodeDef nd : nodes) {
                    if (m.isVariableNode(nd) || m.isPlaceHolderNode(nd))
                    if (TFGraphMapper.isVariableNode(nd) || TFGraphMapper.isPlaceHolder(nd))
                        continue;

                    String op = nd.getOp();
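For context, a minimal sketch of the reworked check above: the validator now parses the GraphDef directly and calls the static TFGraphMapper helpers instead of the old singleton. The file name and counting logic below are illustrative only, not taken from this change.

    // Illustrative sketch only - counts importable op nodes in a frozen TF graph.
    try (InputStream is = new BufferedInputStream(new FileInputStream("frozen_model.pb"))) {   // hypothetical file
        GraphDef graphDef = GraphDef.parseFrom(is);
        int opCount = 0;
        for (int i = 0; i < graphDef.getNodeCount(); i++) {
            NodeDef nd = graphDef.getNode(i);
            if (TFGraphMapper.isVariableNode(nd) || TFGraphMapper.isPlaceHolder(nd))
                continue;   // variables and placeholders are not executable ops
            opCount++;
        }
        System.out.println("Op nodes: " + opCount);
    }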
@@ -86,6 +86,7 @@ import java.io.*;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

import static org.nd4j.linalg.factory.Nd4j.*;

@@ -124,6 +125,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
    protected transient JvmShapeInfo jvmShapeInfo;


    private static final AtomicLong arrayCounter = new AtomicLong(0);
    protected transient final long arrayId = arrayCounter.getAndIncrement();


    //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over
    private static final int[][] tadFinalPermuteDimensions;

@@ -139,7 +143,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
    }

    public BaseNDArray() {

    }

    @Override

@@ -4916,6 +4919,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {

    @Override
    public String toString(@NonNull NDArrayStrings options){
        if(wasClosed())
            return "<Closed NDArray, id=" + getId() + ", dtype=" + dataType() + ", shape=" + Arrays.toString(shape()) + ">";
        if (!isCompressed() && !preventUnpack)
            return options.format(this);
        else if (isCompressed() && compressDebug)

@@ -5600,4 +5605,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {

        return false;
    }

    @Override
    public long getId(){
        return arrayId;
    }
}
@@ -2814,4 +2814,10 @@ public interface INDArray extends Serializable, AutoCloseable {
     * @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int)
     */
    String toStringFull();

    /**
     * A unique ID for the INDArray object instance. Does not account for views.
     * @return INDArray unique ID
     */
    long getId();
}
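Taken together, the new per-instance id and the closed-array toString make leaks easier to trace. A small usage sketch (illustrative, not part of this change), assuming the standard Nd4j factory methods:

    INDArray arr = Nd4j.create(DataType.FLOAT, 2, 2);
    long id = arr.getId();           // unique per INDArray instance; views get their own id
    arr.close();                     // release the underlying buffer early
    System.out.println(arr);         // prints "<Closed NDArray, id=..., dtype=FLOAT, shape=[2, 2]>" rather than throwing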
@ -24,6 +24,7 @@ import onnx.Onnx;
|
|||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -200,47 +201,16 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
|
||||
@Override
|
||||
public void setX(INDArray x) {
|
||||
if (x == null) {
|
||||
if (args() != null && args().length >= 1) {
|
||||
SDVariable firstArg = args()[0];
|
||||
if (firstArg.getArr() != null)
|
||||
this.x = firstArg.getArr();
|
||||
} else
|
||||
throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments");
|
||||
} else
|
||||
this.x = x;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setZ(INDArray z) {
|
||||
if (z == null) {
|
||||
SDVariable getResult = sameDiff.getVariable(zVertexId);
|
||||
if (getResult != null) {
|
||||
if (getResult.getArr() != null)
|
||||
this.z = getResult.getArr();
|
||||
else if(sameDiff.getShapeForVarName(getResult.getVarName()) != null) {
|
||||
val shape = sameDiff.getShapeForVarName(getResult.getVarName());
|
||||
sameDiff.setArrayForVariable(getResult.getVarName(),getResult.getWeightInitScheme().create(getResult.dataType(), shape));
|
||||
}
|
||||
else
|
||||
throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
|
||||
|
||||
} else
|
||||
throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
|
||||
} else
|
||||
this.z = z;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setY(INDArray y) {
|
||||
if (y == null) {
|
||||
if (args() != null && args().length > 1) {
|
||||
SDVariable firstArg = args()[1];
|
||||
if (firstArg.getArr() != null)
|
||||
this.y = firstArg.getArr();
|
||||
} else
|
||||
throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
|
||||
} else
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
|
@@ -265,6 +235,12 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
        return z;
    }

    @Override
    public INDArray getInputArgument(int index){
        Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index);
        return index == 0 ? x : y;
    }

    @Override
    public SDVariable[] outputVariables(String baseName) {
        if(zVertexId == null) {
@@ -403,4 +379,11 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
        //Always 1 for legacy/base ops
        return 1;
    }

    @Override
    public void clearArrays(){
        x = null;
        y = null;
        z = null;
    }
}
@ -16,7 +16,6 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -24,21 +23,14 @@ import lombok.val;
|
|||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -71,10 +63,6 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
this.keepDims = keepDims;
|
||||
this.xVertexId = i_v.getVarName();
|
||||
sameDiff.addArgsFor(new String[]{xVertexId},this);
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variable.");
|
||||
}
|
||||
|
@ -219,14 +207,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
|
||||
@Override
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
if (!attributesForNode.containsKey("axes")) {
|
||||
this.dimensions = new int[] { Integer.MAX_VALUE };
|
||||
}
|
||||
else {
|
||||
val map = OnnxGraphMapper.getInstance().getAttrMap(node);
|
||||
val dims = Ints.toArray(map.get("axes").getIntsList());
|
||||
this.dimensions = dims;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@@ -119,4 +119,9 @@ public interface CustomOp {
     * otherwise throws an {@link org.nd4j.linalg.exception.ND4JIllegalStateException}
     */
    void assertValidForExecution();

    /**
     * Clear the input and output INDArrays, if any are set
     */
    void clearArrays();
}
@@ -263,7 +263,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
    @Override
    public INDArray[] outputArguments() {
        if (!outputArguments.isEmpty()) {
            return outputArguments.toArray(new INDArray[outputArguments.size()]);
            return outputArguments.toArray(new INDArray[0]);
        }
        return new INDArray[0];
    }
@@ -271,7 +271,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
    @Override
    public INDArray[] inputArguments() {
        if (!inputArguments.isEmpty())
            return inputArguments.toArray(new INDArray[inputArguments.size()]);
            return inputArguments.toArray(new INDArray[0]);
        return new INDArray[0];

    }
@@ -389,6 +389,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
    }

    public void setInputArgument(int index, INDArray input) {
        if(index >= inputArguments.size() ){
            List<INDArray> oldArgs = inputArguments;
            inputArguments = new ArrayList<>(index+1);
            inputArguments.addAll(oldArgs);
            while(inputArguments.size() <= index)
                inputArguments.add(null);
        }
        inputArguments.set(index, input);
    }

@@ -400,12 +407,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
    }

    public void setOutputArgument(int index, INDArray output) {
        if(index == outputArguments.size()){
            //For example, setOutputArgument(0,arr) on empty list
            outputArguments.add(output);
        } else {
            outputArguments.set(index, output);
        while(index >= outputArguments.size()){
            //Resize list, in case we want to specify arrays not in order they are defined
            //For example, index 1 on empty list, then index 0
            outputArguments.add(null);
        }
        outputArguments.set(index, output);
    }

    @Override
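A short sketch of what the resize logic above allows (illustrative; the builder call is an assumption about the usual DynamicCustomOp factory, not something this diff introduces):

    DynamicCustomOp op = DynamicCustomOp.builder("add").build();   // assumed generic custom-op factory
    INDArray out0 = Nd4j.create(DataType.FLOAT, 2, 2);
    INDArray out1 = Nd4j.create(DataType.FLOAT, 2, 2);
    op.setOutputArgument(1, out1);   // list is padded with null at index 0, then index 1 is set
    op.setOutputArgument(0, out0);   // the earlier slot can be filled afterwards, in any order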
@@ -608,6 +615,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {

    }

    @Override
    public void clearArrays(){
        inputArguments.clear();
        outputArguments.clear();
    }

    protected static INDArray[] wrapOrNull(INDArray in){
        return in == null ? null : new INDArray[]{in};
    }
@@ -167,4 +167,9 @@ public interface Op {
     * @return the equivalent {@link CustomOp}
     */
    CustomOp toCustomOp();

    /**
     * Clear the input and output INDArrays, if any are set
     */
    void clearArrays();
}
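The clearArrays() additions across Op, CustomOp and their implementations exist so the execution layer can drop array references once an op has run. A hedged sketch of the intended call pattern; the surrounding session and bookkeeping calls here are hypothetical placeholders, not APIs from this change:

    // Hypothetical post-execution step inside an execution session:
    INDArray[] outs = runOp(op);         // hypothetical helper that executes the op
    recordForDependencies(op, outs);     // hypothetical bookkeeping for the dependency tracker
    op.clearArrays();                    // drop input/output INDArray references so closed arrays are not pinned by the op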
@@ -25,6 +25,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast {

    @Override
    public String tensorflowName() {
        return "AdjustContrast";
        return "AdjustContrastV2";
    }
}
@@ -245,4 +245,9 @@ public class ScatterUpdate implements CustomOp {
    public void assertValidForExecution() {

    }

    @Override
    public void clearArrays() {
        op.clearArrays();
    }
}
@@ -39,13 +39,18 @@ import java.util.*;
@NoArgsConstructor
public class BiasAdd extends DynamicCustomOp {

    protected boolean nchw = true;

    public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias) {
    public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) {
        super(null, sameDiff, new SDVariable[] {input, bias}, false);
        bArguments.clear();
        bArguments.add(nchw);
    }

    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
    public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
        super(new INDArray[]{input, bias}, wrapOrNull(output));
        bArguments.clear();
        bArguments.add(nchw);
    }

    @Override
@@ -56,7 +61,11 @@ public class BiasAdd extends DynamicCustomOp {
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

        if(attributesForNode.containsKey("data_format")){
            nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8());
        }
        bArguments.clear();
        bArguments.add(nchw);
    }

    @Override
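A brief usage sketch for the new data-format-aware BiasAdd constructor (illustrative only; the rand, create and exec calls are assumed standard Nd4j factory/executioner APIs):

    INDArray input = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4);    // NCHW activations: [minibatch, channels, h, w]
    INDArray bias  = Nd4j.rand(DataType.FLOAT, 3);             // one bias value per channel
    INDArray out   = Nd4j.create(DataType.FLOAT, 2, 3, 4, 4);  // pre-allocated output
    BiasAdd op = new BiasAdd(input, bias, out, true);           // true => NCHW layout
    Nd4j.exec(op);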
@ -1,402 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.SameDiffConditional;
|
||||
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.util.HashUtil;
|
||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Equivalent to tensorflow's conditional op.
|
||||
* Runs one of 2 {@link SameDiff.SameDiffFunctionDefinition}
|
||||
* depending on a predicate {@link org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional}
|
||||
*
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@Slf4j
|
||||
public class If extends DifferentialFunction implements CustomOp {
|
||||
|
||||
@Getter
|
||||
protected SameDiff loopBodyExecution,predicateExecution,falseBodyExecution;
|
||||
|
||||
|
||||
@Getter
|
||||
protected SameDiffConditional predicate;
|
||||
@Getter
|
||||
protected SameDiffFunctionDefinition trueBody,falseBody;
|
||||
|
||||
@Getter
|
||||
protected String blockName,trueBodyName,falseBodyName;
|
||||
|
||||
@Getter
|
||||
protected SDVariable[] inputVars;
|
||||
|
||||
@Getter
|
||||
protected Boolean trueBodyExecuted = null;
|
||||
|
||||
@Getter
|
||||
protected SDVariable targetBoolean;
|
||||
|
||||
protected SDVariable dummyResult;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected SDVariable[] outputVars;
|
||||
|
||||
public If(If ifStatement) {
|
||||
this.sameDiff = ifStatement.sameDiff;
|
||||
this.outputVars = ifStatement.outputVars;
|
||||
this.falseBodyExecution = ifStatement.falseBodyExecution;
|
||||
this.trueBodyExecuted = ifStatement.trueBodyExecuted;
|
||||
this.falseBody = ifStatement.falseBody;
|
||||
this.trueBodyExecuted = ifStatement.trueBodyExecuted;
|
||||
this.dummyResult = ifStatement.dummyResult;
|
||||
this.inputVars = ifStatement.inputVars;
|
||||
this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme(), DataType.FLOAT, 1);
|
||||
if(sameDiff.getShapeForVarName(dummyResult.getVarName()) == null)
|
||||
sameDiff.putShapeForVarName(dummyResult.getVarName(),new long[]{1,1});
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Builder
|
||||
public If(String blockName,
|
||||
SameDiff parent,
|
||||
SDVariable[] inputVars,
|
||||
SameDiffFunctionDefinition conditionBody,
|
||||
SameDiffConditional predicate,
|
||||
SameDiffFunctionDefinition trueBody,
|
||||
SameDiffFunctionDefinition falseBody) {
|
||||
|
||||
this.sameDiff = parent;
|
||||
parent.putOpForId(getOwnName(),this);
|
||||
this.inputVars = inputVars;
|
||||
this.predicate = predicate;
|
||||
|
||||
parent.addArgsFor(inputVars,this);
|
||||
this.trueBody = trueBody;
|
||||
this.falseBody = falseBody;
|
||||
this.blockName = blockName;
|
||||
//need to add the op to the list of ops to be executed when running backwards
|
||||
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
|
||||
parent.addOutgoingFor(new SDVariable[]{dummyResult},this);
|
||||
|
||||
//create a samediff sub graph for running just the execution
|
||||
//return a reference to the loop for referencing during actual execution
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
//store the reference to the result array and the same diff execution instance
|
||||
this.targetBoolean = predicate.eval(sameDiff,conditionBody, inputVars);
|
||||
this.predicateExecution = sameDiff;
|
||||
//store references to the loop body
|
||||
String trueBodyName = "true-body-" + UUID.randomUUID().toString();
|
||||
this.trueBodyName = trueBodyName;
|
||||
|
||||
String falseBodyName = "false-body-" + UUID.randomUUID().toString();
|
||||
this.falseBodyName = trueBodyName;
|
||||
|
||||
//running define function will setup a proper same diff instance
|
||||
this.loopBodyExecution = parent.defineFunction(trueBodyName,trueBody,inputVars);
|
||||
this.falseBodyExecution = parent.defineFunction(falseBodyName,falseBody,inputVars);
|
||||
parent.defineFunction(blockName,conditionBody,inputVars);
|
||||
parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(),sameDiff);
|
||||
//get a reference to the actual loop body
|
||||
this.loopBodyExecution = parent.getFunction(trueBodyName);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Toggle whether the true body was executed
|
||||
* or the false body
|
||||
* @param trueBodyExecuted
|
||||
*/
|
||||
public void exectedTrueOrFalse(boolean trueBodyExecuted) {
|
||||
if(trueBodyExecuted)
|
||||
this.trueBodyExecuted = true;
|
||||
else
|
||||
this.trueBodyExecuted = false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public SDVariable[] outputVariables(String baseName) {
|
||||
return new SDVariable[]{dummyResult};
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
ret.addAll(Arrays.asList(new IfDerivative(this).outputVariables()));
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return opName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "if";
|
||||
}
|
||||
|
||||
@Override
|
||||
public long opHash() {
|
||||
return HashUtil.getLongHash(opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInplaceCall() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputArguments() {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] inputArguments() {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] iArgs() {
|
||||
return new long[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] tArgs() {
|
||||
return new double[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] bArgs() {
|
||||
return new boolean[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(int... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(long... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addBArgument(boolean... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeIArgument(Integer arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean getBArgument(int index) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getIArgument(int index) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numIArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addTArgument(double... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeTArgument(Double arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getTArgument(int index) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numTArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numBArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addInputArgument(INDArray... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeInputArgument(INDArray arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getInputArgument(int index) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numInputArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addOutputArgument(INDArray... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeOutputArgument(INDArray arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getOutputArgument(int index) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numOutputArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Op.Type opType() {
|
||||
return Op.Type.CONDITIONAL;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
//cond is only part of while loops
|
||||
if(nodeDef.getName().contains("/cond/"))
|
||||
return;
|
||||
//usually should be a merge node for a conditional
|
||||
val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph);
|
||||
|
||||
|
||||
val trueScopeGraphDefBuilder = GraphDef.newBuilder();
|
||||
for(val node : ifNodes.getTrueNodes()) {
|
||||
trueScopeGraphDefBuilder.addNode(node);
|
||||
}
|
||||
|
||||
|
||||
val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build());
|
||||
|
||||
|
||||
val falseScopeGraphDefBuilder = GraphDef.newBuilder();
|
||||
for(val node : ifNodes.getFalseNodes()) {
|
||||
falseScopeGraphDefBuilder.addNode(node);
|
||||
|
||||
}
|
||||
|
||||
val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build());
|
||||
|
||||
|
||||
val condScopeGraphDefBuilder = GraphDef.newBuilder();
|
||||
for(val node : ifNodes.getCondNodes()) {
|
||||
condScopeGraphDefBuilder.addNode(node);
|
||||
|
||||
}
|
||||
|
||||
|
||||
val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build());
|
||||
|
||||
|
||||
|
||||
initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope);
|
||||
initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope);
|
||||
initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope);
|
||||
|
||||
this.loopBodyExecution = trueScope;
|
||||
this.falseBodyExecution = falseScope;
|
||||
this.predicateExecution = condScope;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
||||
return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CustomOpDescriptor getDescriptor() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void assertValidForExecution() {
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("This operation has no TF counterpart");
|
||||
}
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.SameDiffConditional;
|
||||
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class IfDerivative extends If {
|
||||
|
||||
private If ifDelegate;
|
||||
|
||||
public IfDerivative(If ifBlock) {
|
||||
super(ifBlock);
|
||||
this.ifDelegate = ifBlock;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean getTrueBodyExecuted() {
|
||||
return ifDelegate.trueBodyExecuted;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public SameDiffFunctionDefinition getFalseBody() {
|
||||
return ifDelegate.falseBody;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff getFalseBodyExecution() {
|
||||
return ifDelegate.falseBodyExecution;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getBlockName() {
|
||||
return ifDelegate.blockName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFalseBodyName() {
|
||||
return ifDelegate.falseBodyName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff getLoopBodyExecution() {
|
||||
return ifDelegate.loopBodyExecution;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiffConditional getPredicate() {
|
||||
return ifDelegate.getPredicate();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiff getPredicateExecution() {
|
||||
return ifDelegate.predicateExecution;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
||||
return super.calculateOutputShape();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "if_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> diff(List<SDVariable> i_v1) {
|
||||
throw new UnsupportedOperationException("Unable to take the derivative of the derivative for if");
|
||||
}
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
public class IfImportState {
|
||||
private List<NodeDef> condNodes;
|
||||
private List<NodeDef> trueNodes;
|
||||
private List<NodeDef> falseNodes;
|
||||
private String falseBodyScopeName,trueBodyScopeName,conditionBodyScopeName;
|
||||
}
|
|
@@ -55,7 +55,7 @@ public class Select extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    }

@ -1,660 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.controlflow;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.SameDiffConditional;
|
||||
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* Equivalent to tensorflow's while loop
|
||||
* Takes in:
|
||||
* loopVars
|
||||
* loop body
|
||||
* condition
|
||||
*
|
||||
* runs loop till condition is false.
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@Slf4j
|
||||
public class While extends DifferentialFunction implements CustomOp {
|
||||
private AtomicInteger startPosition;
|
||||
|
||||
|
||||
|
||||
@Getter
|
||||
protected SameDiff loopBodyExecution,predicateExecution;
|
||||
|
||||
|
||||
@Getter
|
||||
protected SameDiffConditional predicate;
|
||||
@Getter
|
||||
protected SameDiffFunctionDefinition trueBody;
|
||||
|
||||
@Getter
|
||||
protected String blockName,trueBodyName;
|
||||
|
||||
@Getter
|
||||
protected SDVariable[] inputVars;
|
||||
|
||||
|
||||
@Getter
|
||||
protected SDVariable targetBoolean;
|
||||
|
||||
protected SDVariable dummyResult;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected SDVariable[] outputVars;
|
||||
|
||||
@Getter
|
||||
protected int numLooped = 0;
|
||||
|
||||
/**
|
||||
* Mainly meant for tensorflow import.
|
||||
* This allows {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}
|
||||
* to continue from a parent while loop
|
||||
* using the same graph
|
||||
* @param startPosition the start position for the import scan
|
||||
*/
|
||||
public While(AtomicInteger startPosition) {
|
||||
this.startPosition = startPosition;
|
||||
}
|
||||
|
||||
public While(While whileStatement) {
|
||||
this.sameDiff = whileStatement.sameDiff;
|
||||
this.outputVars = whileStatement.outputVars;
|
||||
this.loopBodyExecution = whileStatement.loopBodyExecution;
|
||||
this.numLooped = whileStatement.numLooped;
|
||||
this.dummyResult = whileStatement.dummyResult;
|
||||
this.predicate = whileStatement.predicate;
|
||||
this.predicateExecution = whileStatement.predicateExecution;
|
||||
this.inputVars = whileStatement.inputVars;
|
||||
this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Builder
|
||||
public While(String blockName,
|
||||
SameDiff parent,
|
||||
SDVariable[] inputVars,
|
||||
SameDiffConditional predicate,
|
||||
SameDiffFunctionDefinition condition,
|
||||
SameDiffFunctionDefinition trueBody) {
|
||||
init(blockName,parent,inputVars,predicate,condition,trueBody);
|
||||
}
|
||||
|
||||
|
||||
private void init(String blockName,
|
||||
SameDiff parent,
|
||||
SDVariable[] inputVars,
|
||||
SameDiffConditional predicate,
|
||||
SameDiffFunctionDefinition condition,
|
||||
SameDiffFunctionDefinition trueBody) {
|
||||
this.sameDiff = parent;
|
||||
this.inputVars = inputVars;
|
||||
this.predicate = predicate;
|
||||
this.trueBody = trueBody;
|
||||
this.blockName = blockName;
|
||||
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
|
||||
parent.putOpForId(getOwnName(),this);
|
||||
|
||||
parent.addArgsFor(inputVars,this);
|
||||
parent.addOutgoingFor(new SDVariable[]{dummyResult},this);
|
||||
|
||||
|
||||
//create a samediff sub graph for running just the execution
|
||||
//return a reference to the loop for referencing during actual execution
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
//store the reference to the result array and the same diff execution instance
|
||||
this.targetBoolean = predicate.eval(sameDiff,condition, inputVars);
|
||||
this.predicateExecution = sameDiff;
|
||||
//store references to the loop body
|
||||
String trueBodyName = "true-body-" + UUID.randomUUID().toString();
|
||||
this.trueBodyName = trueBodyName;
|
||||
//running define function will setup a proper same diff instance
|
||||
parent.defineFunction(trueBodyName,trueBody,inputVars);
|
||||
parent.defineFunction(blockName,condition,inputVars);
|
||||
parent.putSubFunction("predicate-eval-body",sameDiff);
|
||||
//get a reference to the actual loop body
|
||||
this.loopBodyExecution = parent.getFunction(trueBodyName);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public SDVariable[] outputVariables(String baseName) {
|
||||
return new SDVariable[]{dummyResult};
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
ret.addAll(Arrays.asList(new WhileDerivative(this).outputVariables()));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Increments the loop counter.
|
||||
* This should be called when the loop
|
||||
* actually executes.
|
||||
*/
|
||||
public void incrementLoopCounter() {
|
||||
numLooped++;
|
||||
}
|
||||
|
||||
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        doImport(nodeDef,initWith,attributesForNode,graph,new LinkedHashSet<String>(),new AtomicInteger(0));
    }


    private void doImport(NodeDef nodeDef,SameDiff initWith,Map<String,AttrValue> attributesForNode,GraphDef graph,Set<String> skipSet,AtomicInteger currIndex) {
        val uniqueId = java.util.UUID.randomUUID().toString();
        skipSet.add(nodeDef.getName());
        val scopeCondition = SameDiff.create();
        val scopeLoop = SameDiff.create();
        initWith.putSubFunction("condition-" + uniqueId,scopeCondition);
        initWith.putSubFunction("loopbody-" + uniqueId,scopeLoop);
        this.loopBodyExecution = scopeLoop;
        this.predicateExecution = scopeCondition;
        this.startPosition = currIndex;

        log.info("Adding 2 new scopes for WHILE {}");


        val nodes = graph.getNodeList();

        /**
         * Plan is simple:
         * 1) we read all declarations of variables used within loop
         * 2) we set up conditional scope
         * 3) we set up body scope
         * 4) ???
         * 5) PROFIT!
         */

        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("enter")) {
                //skipSet.add(tfNode.getName());
                break;
            }

            //            if (skipSet.contains(tfNode.getName()))
            //                continue;

            skipSet.add(tfNode.getName());

            val vars = new SDVariable[tfNode.getInputCount()];
            for (int e = 0; e < tfNode.getInputCount(); e++) {
                val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
                vars[e] = initWith.getVariable(input) == null ? initWith.var(input, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(input);
                scopeCondition.var(vars[e]);
                scopeLoop.var(vars[e]);
            }

            this.inputVars = vars;
        }


        // now we're skipping Merge step, since we've already captured variables at Enter step
        int mergedCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("merge")) {
                scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor) null,new ZeroInitScheme());
                break;
            }

            skipSet.add(tfNode.getName());
            val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor)null,new ZeroInitScheme());
            scopeCondition.var(var);
            initWith.var(var);
            mergedCnt++;
        }


        // now, we're adding conditional scope
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            // we're parsing up to condition
            if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
                skipSet.add(tfNode.getName());
                currIndex.incrementAndGet();
                break;
            }

            boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
            boolean isVar = tfNode.getOp().startsWith("VariableV");
            boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");

            if (isConst || isVar || isPlaceholder) {
                val var = scopeCondition.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
                scopeLoop.var(var);
                initWith.var(var);
                log.info("Adding condition var [{}]", var.getVarName());
            }
            else if(!skipSet.contains(tfNode.getName())) {
                val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
                func.initFromTensorFlow(tfNode,scopeCondition,nodeDef.getAttrMap(),graph);
                func.setSameDiff(scopeLoop);
            }

            skipSet.add(tfNode.getName());
        }


        // time to skip some Switch calls
        int switchCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            // we're parsing up to condition
            if (!tfNode.getOp().equalsIgnoreCase("Switch"))
                break;

            switchCnt++;
            skipSet.add(tfNode.getName());
        }

        // now we're parsing Identity step
        int identityCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
                break;
            }

            val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
            func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);
            func.setSameDiff(scopeLoop);

            val variables = new SDVariable[tfNode.getInputCount()];
            for(int i = 0; i < tfNode.getInputCount(); i++) {
                val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
                if(testVar == null) {
                    variables[i] = initWith.var(tfNode.getInput(i), (LongShapeDescriptor) null,new ZeroInitScheme());
                    scopeCondition.var(variables[i]);
                    scopeLoop.var(variables[i]);
                    continue;
                }
                else {
                    variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
                    scopeCondition.var(variables[i]);
                    scopeLoop.var(variables[i]);
                }
            }

            scopeLoop.addArgsFor(variables,func);
            skipSet.add(tfNode.getName());
        }


        // parsing body scope
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (skipSet.contains(tfNode.getName())) {
                log.info("Skipping: {}", tfNode.getName());
                continue;
            }

            if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
                //                skipSet.add(tfNode.getName());
                break;
            }

            if (skipSet.contains(tfNode.getName())) {
                log.info("Skipping: {}", tfNode.getName());
                continue;
            }

            boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
            boolean isVar = tfNode.getOp().startsWith("VariableV");
            boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");

            if (isConst || isVar || isPlaceholder) {
                val var = scopeLoop.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
                log.info("Adding body var [{}]",var.getVarName());
            } else {
                log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());

                if (tfNode.getOp().equalsIgnoreCase("enter")) {
                    log.info("NEW LOOP ----------------------------------------");
                    val func = new While(currIndex);
                    func.doImport(nodeDef,initWith,attributesForNode,graph,skipSet,currIndex);
                    func.setSameDiff(initWith);
                    log.info("END LOOP ----------------------------------------");
                } else {
                    val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());

                    func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);

                    func.setSameDiff(scopeCondition);

                    val variables = new SDVariable[tfNode.getInputCount()];
                    for(int i = 0; i < tfNode.getInputCount(); i++) {
                        val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
                        variables[i] = scopeCondition.getVariable(name);
                        if(variables[i] == null) {
                            if(scopeLoop.getVariable(name) == null)
                                variables[i] = scopeCondition.var(initWith.getVariable(name));
                            else if(scopeLoop.getVariable(name) != null)
                                variables[i] = scopeLoop.getVariable(name);
                            else
                                variables[i] = scopeLoop.var(name, Nd4j.scalar(1.0));
                        }
                    }

                    scopeLoop.addArgsFor(variables,func);
                }
            }

            skipSet.add(tfNode.getName());
        }


        val returnInputs = new ArrayList<SDVariable>();
        val returnOutputs = new ArrayList<SDVariable>();

        // mapping NextIterations, to Return op
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
                break;

            skipSet.add(tfNode.getName());

            val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
            val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
            returnInputs.add(input);
        }


        this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
        this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
        initWith.addArgsFor(inputVars,this);
        initWith.addOutgoingFor(outputVars,this);

        // we should also map While/Exit to libnd4j while
        int exitCnt = 0;
        for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
            val tfNode = nodes.get(currIndex.get());

            if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
                //skipSet.add(tfNode.getName());
                break;
            }

            skipSet.add(tfNode.getName());
            val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
            val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
        }


        //the output of the condition should always be a singular scalar
        //this is a safe assumption
        val conditionVars = scopeCondition.ops();
        if(conditionVars.length < 1) {
            throw new ND4JIllegalArgumentException("No functions found!");
        }
        this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];

        log.info("-------------------------------------------");
    }
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {

    }


    @Override
    public String toString() {
        return opName();
    }

    @Override
    public String opName() {
        return "while";
    }

    @Override
    public long opHash() {
        return opName().hashCode();
    }

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public long[] iArgs() {
        return new long[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public void addIArgument(int... arg) {

    }

    @Override
    public void addIArgument(long... arg) {

    }

    @Override
    public void removeIArgument(Integer arg) {

    }

    @Override
    public Long getIArgument(int index) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double... arg) {

    }

    @Override
    public void removeTArgument(Double arg) {

    }

    @Override
    public Double getTArgument(int index) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public int numBArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray... arg) {

    }

    @Override
    public void removeInputArgument(INDArray arg) {

    }

    @Override
    public boolean[] bArgs() {
        return new boolean[0];
    }

    @Override
    public void addBArgument(boolean... arg) {

    }

    @Override
    public Boolean getBArgument(int index) {
        return null;
    }

    @Override
    public INDArray getInputArgument(int index) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray... arg) {

    }

    @Override
    public void removeOutputArgument(INDArray arg) {

    }

    @Override
    public INDArray getOutputArgument(int index) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        List<LongShapeDescriptor> ret = new ArrayList<>();
        for(SDVariable var : args()) {
            ret.add(sameDiff.getShapeDescriptorForVarName(var.getVarName()));
        }
        return ret;
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return CustomOpDescriptor.builder().build();
    }

    @Override
    public void assertValidForExecution() {

    }


    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No singular TensorFlow op name found for " + opName() + " (use tensorflowNames() instead)");
    }

    @Override
    public String[] tensorflowNames() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }


    @Override
    public Op.Type opType() {
        return Op.Type.LOOP;
    }
}
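For orientation, doImport(...) above effectively assumes that the GraphDef lists a TensorFlow v1 while-loop frame in a fixed phase order and walks it with a single cursor (currIndex). A minimal standalone sketch of that ordering, with made-up node names and ops (illustration only, not code from this PR):

```java
/**
 * Sketch of the node ordering a TF v1 while-loop frame is assumed to arrive in:
 * Enter* -> Merge* -> (condition ops) -> LoopCond -> Switch* -> Identity* ->
 * (body ops) -> NextIteration* -> Exit*.
 */
public class WhileFrameOrderSketch {
    public static void main(String[] args) {
        // Hypothetical frame for a loop roughly equivalent to: while (i < 10) i = i + 1
        String[][] nodes = {
                {"while/Enter", "Enter"},
                {"while/Merge", "Merge"},
                {"while/Less/y", "Const"},
                {"while/Less", "Less"},
                {"while/LoopCond", "LoopCond"},
                {"while/Switch", "Switch"},
                {"while/Identity", "Identity"},
                {"while/add/y", "Const"},
                {"while/add", "Add"},
                {"while/NextIteration", "NextIteration"},
                {"while/Exit", "Exit"}
        };

        int i = 0;
        i = consume(nodes, i, "Enter");               // loop inputs -> inputVars
        i = consume(nodes, i, "Merge");               // skipped; variables already captured at Enter
        while (i < nodes.length && !nodes[i][1].equals("LoopCond")) {
            System.out.println("condition op: " + nodes[i][0]);   // -> scopeCondition
            i++;
        }
        i++;                                          // consume LoopCond itself
        i = consume(nodes, i, "Switch");              // skipped
        i = consume(nodes, i, "Identity");            // mapped into the body scope
        while (i < nodes.length && !nodes[i][1].equals("NextIteration")) {
            System.out.println("body op: " + nodes[i][0]);        // -> scopeLoop
            i++;
        }
        i = consume(nodes, i, "NextIteration");       // become the loop's outputs
        i = consume(nodes, i, "Exit");                // frame results
    }

    // Advance the cursor over a run of nodes that all share the given op type.
    private static int consume(String[][] nodes, int i, String op) {
        while (i < nodes.length && nodes[i][1].equals(op)) {
            System.out.println(op + ": " + nodes[i][0]);
            i++;
        }
        return i;
    }
}
```

Each phase in the sketch corresponds to one of the cursor loops above: Enter nodes become inputVars, the ops up to LoopCond populate scopeCondition, the ops between the Identity nodes and NextIteration populate scopeLoop, and the NextIteration/Exit nodes are mapped to the loop's outputs and targetBoolean.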
@@ -1,96 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.controlflow;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.Op;

/**
 * While loop derivative
 * @author Adam Gibson
 */
@NoArgsConstructor
public class WhileDerivative extends While {
    private While delegate;

    public WhileDerivative(While delegate) {
        super(delegate);
        this.delegate = delegate;
    }


    @Override
    public SameDiffFunctionDefinition getTrueBody() {
        return delegate.trueBody;
    }

    @Override
    public String getTrueBodyName() {
        return delegate.getTrueBodyName();
    }

    @Override
    public SameDiffConditional getPredicate() {
        return delegate.getPredicate();
    }

    @Override
    public SameDiff getPredicateExecution() {
        return delegate.getPredicateExecution();
    }

    @Override
    public SDVariable[] getInputVars() {
        return delegate.getInputVars();
    }

    @Override
    public String getBlockName() {
        return delegate.getBlockName();
    }

    @Override
    public SameDiff getLoopBodyExecution() {
        return delegate.getLoopBodyExecution();
    }

    @Override
    public int getNumLooped() {
        return delegate.getNumLooped();
    }

    @Override
    public String opName() {
        return "while_bp";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CONDITIONAL;
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow name for while backprop");
    }
}
@@ -55,7 +55,7 @@ public abstract class BaseCompatOp extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
    }

    @Override
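The hunk above shows the change that repeats through most of the remaining files: TFGraphMapper's initFunctionFromProperties is now invoked as a static method rather than through a singleton, so every call site drops the getInstance() hop. A minimal sketch of that refactoring pattern, using hypothetical class names rather than the real TFGraphMapper:

```java
// Hypothetical "before" shape: behaviour reached through a singleton instance.
class GraphMapperBefore {
    private static final GraphMapperBefore INSTANCE = new GraphMapperBefore();
    static GraphMapperBefore getInstance() { return INSTANCE; }
    void initFunctionFromProperties(String opName) { /* map op properties here */ }
}

// Hypothetical "after" shape: the same behaviour exposed as a static utility.
class GraphMapperAfter {
    static void initFunctionFromProperties(String opName) { /* map op properties here */ }
}

class CallSiteSketch {
    void before() { GraphMapperBefore.getInstance().initFunctionFromProperties("Conv2D"); }
    void after()  { GraphMapperAfter.initFunctionFromProperties("Conv2D"); }
}
```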
@@ -32,9 +32,11 @@ import java.util.List;
import java.util.Map;

public class LoopCond extends BaseCompatOp {
    public static final String OP_NAME = "loop_cond";

    @Override
    public String opName() {
        return "loop_cond";
        return OP_NAME;
    }

    @Override


@@ -74,8 +74,6 @@ public class CropAndResize extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        String method = attributesForNode.get("method").getS().toStringUtf8();
        if(method.equalsIgnoreCase("nearest")){
            this.method = Method.NEAREST;
@@ -120,4 +120,10 @@ public class ExtractImagePatches extends DynamicCustomOp {
        //TF includes redundant leading and trailing 1s for kSizes, strides, rates (positions 0/3)
        return new int[]{(int)ilist.getI(1), (int)ilist.getI(2)};
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}
@@ -74,7 +74,7 @@ public class ResizeBilinear extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        this.alignCorners = attributesForNode.get("align_corners").getB();
        addArgs();


@@ -50,7 +50,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    }

    @Override


@@ -26,8 +26,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -41,7 +39,6 @@ import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.lang.reflect.Field;
import java.util.*;


@@ -106,7 +103,7 @@ public class BatchNorm extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta]
        SameDiffOp op = initWith.getOps().get(this.getOwnName());
        List<String> list = op.getInputsToOp();
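The //Switch order comment in the hunk above describes a pure permutation of the op's input list. A standalone sketch of that reordering, with plain strings standing in for the real SameDiffOp input names (illustration only):

```java
import java.util.Arrays;
import java.util.List;

/**
 * TF order [input, gamma, beta, mean, variance] rearranged to the libnd4j order
 * [input, mean, variance, gamma, beta], expressed as an index permutation.
 */
public class BatchNormInputOrderSketch {
    public static void main(String[] args) {
        List<String> tfOrder = Arrays.asList("input", "gamma", "beta", "mean", "variance");
        int[] permutation = {0, 3, 4, 1, 2};   // position i of the libnd4j order comes from tfOrder[permutation[i]]
        String[] libnd4jOrder = new String[tfOrder.size()];
        for (int i = 0; i < permutation.length; i++) {
            libnd4jOrder[i] = tfOrder.get(permutation[i]);
        }
        System.out.println(Arrays.toString(libnd4jOrder)); // [input, mean, variance, gamma, beta]
    }
}
```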
@@ -140,8 +137,7 @@ public class BatchNorm extends DynamicCustomOp {

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
        addArgs();
    }

    @Override


@@ -21,33 +21,20 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.lang.reflect.Field;
import java.util.*;
import java.util.Collections;
import java.util.List;
import java.util.Map;


/**


@@ -31,7 +31,6 @@ import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.*;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -122,7 +121,7 @@ public class Conv2D extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }

@@ -138,8 +137,7 @@ public class Conv2D extends DynamicCustomOp {

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
        addArgs();
    }


@@ -251,7 +251,7 @@ public class Conv3D extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }


@@ -198,7 +198,7 @@ public class DeConv2D extends DynamicCustomOp {
        val args = args();
        INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
        if (arr == null) {
            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
            // TODO: arguable. it might be easier to permute weights once
            //arr = (arr.permute(3, 2, 0, 1).dup('c'));
            val varForOp = initWith.getVariable(args[1].getVarName());


@@ -214,7 +214,7 @@ public class DeConv2DTF extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }

@@ -240,9 +240,9 @@ public class DeConv2DTF extends DynamicCustomOp {
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //inShape, weights, input
        int n = args().length;
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
        return Collections.singletonList(inputDataTypes.get(2));
    }
}
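The second DeConv2DTF hunk above changes which input the output datatype is taken from: its inputs are [inShape, weights, input], so the result should inherit the datatype of the actual input at index 2 rather than of the integer shape tensor at index 0. The ScatterNd hunk further below applies the same rule with inputs [indices, updates, shape], taking the datatype of updates at index 1. A standalone sketch of that inference rule (the enum is a stand-in, not the real nd4j DataType):

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class OutputDataTypeSketch {
    enum DataType { INT32, FLOAT }

    // Validate the input arity, then propagate the datatype of the designated input.
    static List<DataType> outputTypeFromInput(List<DataType> inputTypes, int expectedCount, int sourceIndex) {
        if (inputTypes == null || inputTypes.size() != expectedCount)
            throw new IllegalStateException("Expected " + expectedCount + " input datatypes, got " + inputTypes);
        return Collections.singletonList(inputTypes.get(sourceIndex));
    }

    public static void main(String[] args) {
        // [inShape, weights, input] -> output datatype follows the input, not the shape tensor
        List<DataType> deconvInputs = Arrays.asList(DataType.INT32, DataType.FLOAT, DataType.FLOAT);
        System.out.println(outputTypeFromInput(deconvInputs, 3, 2)); // [FLOAT]
    }
}
```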
@@ -160,7 +160,7 @@ public class DeConv3D extends DynamicCustomOp {
        val args = args();
        INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
        if (arr == null) {
            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
            val varForOp = initWith.getVariable(args[1].getVarName());
            if (arr != null)
                initWith.associateArrayWithVariable(arr, varForOp);


@@ -77,7 +77,7 @@ public class DepthToSpace extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        boolean isNHWC = dataFormat.equals("NHWC");
        addIArgument(blockSize, isNHWC ? 1 : 0);
    }


@@ -29,14 +29,15 @@ import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.*;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;

@@ -136,7 +137,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();

        /*

@@ -162,8 +163,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
        addArgs();
    }


@@ -75,7 +75,7 @@ public class SpaceToDepth extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
        addIArgument(blockSize, isNHWC ? 1 : 0);
    }


@@ -64,7 +64,7 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }


@@ -55,7 +55,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        //Switch order: TF uses [logits, labels]; libnd4j expects [labels, logits]
        SameDiffOp op = initWith.getOps().get(this.getOwnName());


@@ -64,7 +64,7 @@ public class Moments extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }


@@ -60,7 +60,7 @@ public class NormalizeMoments extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }


@@ -63,7 +63,7 @@ public class ScatterAdd extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -86,7 +86,7 @@ public class ScatterDiv extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -60,7 +60,7 @@ public class ScatterMax extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -60,7 +60,7 @@ public class ScatterMin extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -62,7 +62,7 @@ public class ScatterMul extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -67,7 +67,7 @@ public class ScatterNd extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {

@@ -80,8 +80,8 @@ public class ScatterNd extends DynamicCustomOp {
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //Indices, updates, shape
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(1));
    }


@@ -66,7 +66,7 @@ public class ScatterNdAdd extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -66,7 +66,7 @@ public class ScatterNdSub extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -66,7 +66,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -79,7 +79,7 @@ public class ScatterSub extends DynamicCustomOp {

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

        if (nodeDef.containsAttr("use_locking")) {
            if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {