A RuntimeError Exception Is Raised While Compiling a TensorFlow Frozen Graph with the Relay API

Hi there, I’m a newbie to TVM.

First, I followed the from_tensorflow.py tutorial to compile a TensorFlow frozen graph using the Relay API. It works.

However, that pb file comes from the TVM GitHub repository, so I wanted to try an official pre-trained model instead. You can find many pre-trained models at https://github.com/tensorflow/models/blob/master/research/slim/README.md#pre-trained-models.

The model I selected is InceptionV1, which is distributed as a checkpoint file. As far as I know, TVM can only convert a frozen graph at the moment, so I saved it as a frozen graph and ran a few inferences to make sure the model behaves as expected. Then I followed the tutorial to compile the model I froze, and it failed.

The error message is as follows:

Traceback (most recent call last):
  File "compile_tf.py", line 95, in <module>
    main(args)
  File "compile_tf.py", line 77, in main
    mod, params = relay.frontend.from_tensorflow(graph_def, layout='NCHW', shape=shape_dict, outputs=[output_node_name])
  File "/root/tvm/incubator-tvm/python/tvm/relay/frontend/tensorflow.py", line 2512, in from_tensorflow
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
  File "/root/tvm/incubator-tvm/python/tvm/relay/frontend/tensorflow.py", line 2150, in from_tensorflow
    op = self._convert_operator(node.op, inputs, attr, graph)
  File "/root/tvm/incubator-tvm/python/tvm/relay/frontend/tensorflow.py", line 2474, in _convert_operator
    sym = convert_map[op_name](inputs, attrs, self._params)
  File "/root/tvm/incubator-tvm/python/tvm/relay/frontend/tensorflow.py", line 619, in _impl
    raise RuntimeError("If shape operator is used in reshape to "
RuntimeError: If shape operator is used in reshape to express reshape_like, shape_of must be the direct ancestor of reshape when input shape is symbolic.

I have also tried other models, such as inception_resnet_v2, inception_v2/v3/v4, and even mobilenet; the same exception is raised.
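I dug into python/tvm/relay/frontend/tensorflow.py a bit. If I read the Reshape converter correctly, a Reshape whose target shape is computed from the input tensor at runtime is only accepted when a Shape op feeds the Reshape directly (so it can be rewritten as reshape_like). Below is a minimal sketch of the kind of pattern that I believe trips the check; it is illustrative only, not the exact subgraph from InceptionV1:

import tensorflow as tf  # TensorFlow 1.12.3

x = tf.placeholder(tf.float32, shape=(None, 1001))  # symbolic batch dimension
n = tf.shape(x)[0]  # dynamic batch size, taken from a Shape op
# The target shape depends on the input tensor but does not come directly
# from the Shape op (a StridedSlice and a Pack sit in between), which is
# the situation the RuntimeError complains about.
y = tf.reshape(x, tf.stack([n, 1001]))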

TVM is installed in a container whose base image is ubuntu:18.04. Here is the version info of my tools:

TensorFlow: 1.12.3
TensorFlow models: 833e6939acb42f695b0ae3765f98fe494f06115c
TVM: commit 14a5a35882e7369508578705edbc48922b5a0e9a, built with LLVM 8

Here is the code I used to freeze the graph.

import tensorflow as tf    # 1.12.3
from nets.nets_factory import networks_map, arg_scopes_map  # tensorflow/models 833e6939acb42f695b0ae3765f98fe494f06115c
from tensorflow.python.framework import graph_util

model_name = 'inception_v1'
network_fn = networks_map[model_name]
scope_fn   = arg_scopes_map[model_name]

# Create graph by using slim
input_size = network_fn.default_image_size
graph = tf.Graph()
with graph.as_default():
    input_node = tf.placeholder(tf.float32, shape=(None, input_size, input_size, 3))
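    # NOTE: the batch dimension is None, which leaves a symbolic shape in the frozen graph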
    with tf.contrib.slim.arg_scope(scope_fn()):
        _, end_points = network_fn(input_node, num_classes=1001, is_training=False)
    output_node = end_points['Predictions']

# Load checkpoint and save as a frozen graph
ckpt = 'inception_v1.ckpt'
with tf.Session(graph=graph) as sess:
    # Restore parameters
    saver = tf.train.Saver()
    saver.restore(sess, ckpt)

    # Freeze model
    graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph.as_graph_def(),
        [output_node.name.split(':')[0]],
    )

    # Save frozen model
    with tf.gfile.GFile('inception_v1_frozen.pb', 'wb') as fp:
        fp.write(graph_def.SerializeToString())
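For completeness, this is roughly the sanity check I ran on the frozen graph before handing it to TVM. It is a minimal sketch; the two tensor names are the ones I saw in my frozen graph (the unnamed placeholder and slim’s Predictions endpoint), so treat them as assumptions:

import numpy as np
import tensorflow as tf

gdef = tf.GraphDef()
with tf.gfile.GFile('inception_v1_frozen.pb', 'rb') as fp:
    gdef.ParseFromString(fp.read())

with tf.Graph().as_default() as g:
    tf.import_graph_def(gdef, name='')
    inp = g.get_tensor_by_name('Placeholder:0')  # assumed input tensor name
    out = g.get_tensor_by_name('InceptionV1/Logits/Predictions/Reshape_1:0')  # assumed output tensor name
    with tf.Session(graph=g) as sess:
        probs = sess.run(out, feed_dict={inp: np.random.rand(1, 224, 224, 3)})
        print(probs.shape)  # expect (1, 1001)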

Here is the code I used to compile the TensorFlow frozen graph.

import tensorflow as tf
import tvm
import tvm.relay.testing.tf as tf_testing

from tvm import relay

graph_def = tf.GraphDef()
with tf.gfile.GFile('inception_v1_frozen.pb', 'rb') as fp:
    graph_def.ParseFromString(fp.read())
tf.import_graph_def(graph_def, name='')  # imports into the default graph; returns None, not a graph
graph = tf.get_default_graph()
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

# Get the input node name, assuming there is exactly one input node.
input_node_name = [node.name for node in graph_def.node if len(node.input) == 0 and node.op != 'Const']
input_node_name = input_node_name[0]
shape_dict = {input_node_name: (1, 224, 224, 3)}  # pin the symbolic batch dimension to 1

# Get the output node name, assuming there is exactly one output node,
# i.e. a node whose name never appears in another node's input list.
graph_dict = dict()
for node in graph_def.node:
    graph_dict[node.name] = node.input
for name_src in graph_dict:
    found = False
    for name_dst in graph_dict:
        if name_src == name_dst: continue
        if name_src in graph_dict[name_dst]:
            found = True
            break
    if not found:
        output_node_name = name_src
        break

with tf.Session(graph=graph) as sess:
    graph_def = tf_testing.AddShapesToGraphDef(sess, output_node_name)

mod, params = relay.frontend.from_tensorflow(graph_def, layout='NCHW', shape=shape_dict, outputs=[output_node_name])
# A RuntimeError is raised here
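For reference, if from_tensorflow succeeded, the next step would be the usual build from the tutorial (this matches the Relay API at the TVM commit above, as far as I can tell):

# Never reached in my case; from_tensorflow raises the RuntimeError first.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target='llvm', params=params)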

Try this patch: https://github.com/apache/incubator-tvm/pull/4285

@kevinthesun Great! This patch works. I tried inception_v1, inception_v2, inception_v3, inception_v4, inception_resnet_v2, mobilenet_v1_0.25_128, mobilenet_v1_0.5_160, mobilenet_v1_1.0_224, mobilenet_v2_1.0_224, and mobilenet_v2_1.4_224. All of them work.

Thanks a lot.