Hi,
I am trying to import a TensorFlow graph that looks like this.
Format: name: op (output shape)
input: Placeholder (1, 224, 224, 3)
⋮
⋮
reshape: Reshape (1, 1001)
bottleneck: PlaceholderWithDefault (1, 1001)
⋮
⋮
output: Softmax (1, 11)
I am following this tutorial: https://docs.tvm.ai/tutorials/frontend/from_tensorflow.html#import-model
import tensorflow as tf
from tvm import relay
import tvm.relay.testing.tf as tf_testing

with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, OUTPUT_NODE)
mod, params = relay.frontend.from_tensorflow(graph_def)
My problem is that I cannot import the full graph.
If I use OUTPUT_NODE = 'reshape', I get a Relay graph from the input node up to the reshape node.
If I use OUTPUT_NODE = 'output', I get a Relay graph from the bottleneck node up to the output node.
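To make the two results concrete, here is a minimal sketch of how I picture them. It is pure Python and does not call TVM or TensorFlow. Everything in it is an assumption for illustration: the toy graph collapses the elided intermediate nodes into single edges, it assumes the default value of bottleneck comes from reshape, and the hypothetical `reachable` helper assumes the importer walks backwards from the output node and stops at any Placeholder or PlaceholderWithDefault node, treating it as a graph input. I have not confirmed that this is what the frontend actually does.

```python
# Toy model of the graph above: name -> (op, list of input names).
# Assumption: elided intermediate nodes are collapsed into single edges.
TOY_GRAPH = {
    "input": ("Placeholder", []),
    "reshape": ("Reshape", ["input"]),
    # Assumption: the placeholder's default value is produced by `reshape`.
    "bottleneck": ("PlaceholderWithDefault", ["reshape"]),
    "output": ("Softmax", ["bottleneck"]),
}

# Ops that the hypothetical importer treats as graph inputs.
INPUT_OPS = {"Placeholder", "PlaceholderWithDefault"}


def reachable(graph, output_node, stop_at_inputs=True):
    """Return the set of node names visited walking backwards from output_node."""
    seen = set()
    stack = [output_node]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        op, inputs = graph[name]
        # When stop_at_inputs is set, do not follow edges past input-like ops.
        if stop_at_inputs and op in INPUT_OPS:
            continue
        stack.extend(inputs)
    return seen


if __name__ == "__main__":
    print(sorted(reachable(TOY_GRAPH, "reshape")))
    print(sorted(reachable(TOY_GRAPH, "output")))
    print(sorted(reachable(TOY_GRAPH, "output", stop_at_inputs=False)))
```

Under these assumptions the first two calls reproduce what I see (input to reshape, and bottleneck to output), while the third call, which does not stop at PlaceholderWithDefault, covers the whole chain. That is the behaviour I am after.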
How can I import the full graph from the input node to the output node?