[Solved][Tensorflow][Relay] Import issue with graph having two Placeholders


#1

Hi,

I am trying to import a TF graph that looks like this:

Format: name: op (shape)

input: Placeholder (1, 224, 224, 3)
⋮
⋮
reshape: Reshape (1, 1001)
bottleneck: PlaceholderWithDefault (1, 1001)
⋮
⋮
output: Softmax (1, 11)

I am following this tutorial: https://docs.tvm.ai/tutorials/frontend/from_tensorflow.html#import-model

# TF 1.x APIs (tf.gfile, tf.GraphDef, tf.Session), as in the tutorial
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
from tvm import relay

model_path = 'model.pb'
OUTPUT_NODE = 'output'  # or 'reshape'

with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, OUTPUT_NODE)

mod, params = relay.frontend.from_tensorflow(graph_def)

My problem is that I cannot import the full graph.
If I use OUTPUT_NODE = 'reshape', I get a Relay graph from the input node up to the reshape node.
If I use OUTPUT_NODE = 'output', I get a Relay graph from the bottleneck node up to the output node.

How can I import the full graph from the input node to the output node?
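A toy model may help show why the graph splits at bottleneck. It rests on an assumption that this thread does not confirm: that the TVM TensorFlow frontend of this era treats PlaceholderWithDefault the same way as Placeholder, creating a fresh input variable and ignoring the node's default-value input (here, reshape). Under that assumption, walking upstream from output stops at bottleneck, which matches the behavior described above. The dict-based GRAPH and the trace_subgraph helper are hypothetical stand-ins for a real GraphDef and the frontend's conversion loop, and the elided ⋮ layers are collapsed.

```python
# Hypothetical toy graph: node name -> (op, list of input names).
# It mirrors the post's structure with the elided layers collapsed.
GRAPH = {
    "input": ("Placeholder", []),
    "reshape": ("Reshape", ["input"]),
    "bottleneck": ("PlaceholderWithDefault", ["reshape"]),
    "output": ("Softmax", ["bottleneck"]),
}

# Assumption: both op types become free input variables in the frontend.
INPUT_OPS = {"Placeholder", "PlaceholderWithDefault"}


def trace_subgraph(graph, output):
    """Return the nodes reached by walking upstream from `output`,
    stopping at any node whose op is treated as an input."""
    seen = set()
    stack = [output]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        op, inputs = graph[name]
        if op in INPUT_OPS:
            # Input node: its own inputs (e.g. a default value) are not followed.
            continue
        stack.extend(inputs)
    return seen


print(sorted(trace_subgraph(GRAPH, "reshape")))  # ['input', 'reshape']
print(sorted(trace_subgraph(GRAPH, "output")))   # ['bottleneck', 'output']
```

If the assumption holds, no choice of OUTPUT_NODE can yield the full graph, because the path from output back to input is cut at bottleneck.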


#2

from_tensorflow accepts some optional arguments, such as shape, a dictionary mapping input names to shapes. This is how I import graphs with multiple inputs.
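As a sketch of this suggestion, the shape dictionary for this graph could be assembled from the input-like nodes listed in the original post. The (name, op, shape) tuples and the build_shape_dict helper are hypothetical; with a real model the names and shapes would come from graph_def.node. The resulting dict is what would be passed as shape= to relay.frontend.from_tensorflow, shown only as a comment because that call needs TVM and a loaded graph_def. As the following reply reports, specifying shapes alone did not recover the full graph in this case.

```python
# Hypothetical node listing standing in for iterating over graph_def.node.
NODES = [
    ("input", "Placeholder", (1, 224, 224, 3)),
    ("reshape", "Reshape", (1, 1001)),
    ("bottleneck", "PlaceholderWithDefault", (1, 1001)),
    ("output", "Softmax", (1, 11)),
]


def build_shape_dict(nodes, input_ops=("Placeholder", "PlaceholderWithDefault")):
    """Map each input-like node name to its shape, for the `shape`
    argument of relay.frontend.from_tensorflow."""
    return {name: shape for name, op, shape in nodes if op in input_ops}


shape_dict = build_shape_dict(NODES)
print(shape_dict)  # {'input': (1, 224, 224, 3), 'bottleneck': (1, 1001)}

# With TVM installed and graph_def loaded (not run here):
#   mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
```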


#3

Even when I explicitly specify the shapes of those 4 nodes using shape, I still only get the second part of the graph (from bottleneck to output).

I added the shapes to the original post. Which shapes do you recommend I specify in the shape parameter?


#4

I got it working. I did not change any of my TVM code; instead, I used the TensorFlow Graph Transform tool (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md) to remove the PlaceholderWithDefault node.

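# Build the tool once, from the root of a TensorFlow source checkout
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \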
  --in_graph=model.pb \
  --out_graph=model_transformed.pb \
  --inputs='input' \
  --outputs='output' \
  --transforms='
remove_nodes(op=PlaceholderWithDefault)
'
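To illustrate what remove_nodes(op=PlaceholderWithDefault) does to this graph, here is a simplified model of the rewiring the Graph Transform README describes: matching single-input nodes are deleted and their consumers are pointed at the removed node's input instead. The dict layout and the remove_nodes function below are hypothetical Python, not the tool's actual implementation; the sketch ignores tensor suffixes such as :0, control inputs, and the single-output check. In the collapsed toy graph, output now reads from reshape, so input is the only remaining Placeholder and the path from input to output is unbroken.

```python
# Hypothetical toy graph: node name -> (op, list of input names).
GRAPH = {
    "input": ("Placeholder", []),
    "reshape": ("Reshape", ["input"]),
    "bottleneck": ("PlaceholderWithDefault", ["reshape"]),
    "output": ("Softmax", ["bottleneck"]),
}


def remove_nodes(graph, op):
    """Drop single-input nodes of type `op` and rewire their consumers
    to read from the removed node's input instead."""
    # For each removed node, the node its consumers should read from.
    bypass = {name: inputs[0] for name, (node_op, inputs) in graph.items()
              if node_op == op and len(inputs) == 1}

    def resolve(name):
        # Follow chains of removed nodes back to a surviving node.
        while name in bypass:
            name = bypass[name]
        return name

    return {name: (node_op, [resolve(i) for i in inputs])
            for name, (node_op, inputs) in graph.items()
            if name not in bypass}


new_graph = remove_nodes(GRAPH, "PlaceholderWithDefault")
print(new_graph["output"])  # ('Softmax', ['reshape'])
```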