I am trying to import a TF graph that looks like this:
```
Format: name:op
input: Placeholder (1, 224, 224, 3)
⋮
⋮
reshape: Reshape (1, 1001)
bottleneck: PlaceholderWithDefault (1, 1001)
⋮
⋮
output: Softmax (1, 11)
```
I am following this tutorial: https://docs.tvm.ai/tutorials/frontend/from_tensorflow.html#import-model
```python
import tensorflow as tf                      # TF 1.x API
import tvm.relay.testing.tf as tf_testing
from tvm import relay

OUTPUT_NODE = 'output'  # or 'reshape', see below

with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, OUTPUT_NODE)

mod, params = relay.frontend.from_tensorflow(graph_def)
```
My problem is that I cannot import the full graph.
If I use `OUTPUT_NODE = 'reshape'`, I get a Relay graph from the input node up to the reshape node.
If I use `OUTPUT_NODE = 'output'`, I get a Relay graph from the bottleneck node up to the output node.
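To make the two observations concrete, here is a small pure-Python toy model (not TVM or TensorFlow code). It rests on two assumptions: that the importer treats both `Placeholder` and `PlaceholderWithDefault` nodes as free graph inputs and does not walk past them, and that `bottleneck` takes `reshape` as its default input, which the matching (1, 1001) shapes suggest. The elided ⋮ sections are collapsed into two hypothetical stand-in nodes, `body` and `head`, with a made-up op name `Elided`.

```python
# Toy model of the graph in the question (pure Python, not TVM/TensorFlow).
# Each entry maps a node name to (op, list of input node names).
# "body" and "head" are hypothetical stand-ins for the elided sections.
GRAPH = {
    "input": ("Placeholder", []),
    "body": ("Elided", ["input"]),
    "reshape": ("Reshape", ["body"]),
    # Assumption: reshape feeds bottleneck as its default value.
    "bottleneck": ("PlaceholderWithDefault", ["reshape"]),
    "head": ("Elided", ["bottleneck"]),
    "output": ("Softmax", ["head"]),
}

# Hypothesis: ops that the importer turns into free graph inputs.
INPUT_OPS = {"Placeholder", "PlaceholderWithDefault"}


def reachable(graph, output, input_ops=INPUT_OPS):
    """Walk backward from `output`; return visited node names in visit order.

    Nodes whose op is in `input_ops` are kept but not walked past, so a
    graph input cuts off everything upstream of it.
    """
    seen = []
    stack = [output]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.append(name)
        op, inputs = graph[name]
        if op in input_ops:
            continue  # treated as an input: stop walking here
        stack.extend(inputs)
    return seen


print(reachable(GRAPH, "reshape"))  # ['reshape', 'body', 'input']
print(reachable(GRAPH, "output"))   # ['output', 'head', 'bottleneck']
# If PlaceholderWithDefault were passed through instead of treated as an input:
print(reachable(GRAPH, "output", {"Placeholder"}))
# ['output', 'head', 'bottleneck', 'reshape', 'body', 'input']
```

Under these assumptions the toy model reproduces both partial graphs described above; only when the bottleneck node is not treated as an input does a walk from `output` reach `input`.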
How can I import the full graph from the input node to the output node?