Hi folks,
It seems that tf.graph_util.convert_variables_to_constants breaks the Relay TensorFlow frontend (TF FE). Specifically, because the call removes Switch nodes from the graph, the TF FE fails to match Merge nodes with their corresponding Switch nodes (right now they seem to be matched by name here).
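To make the suspected failure mode concrete, here is a hypothetical, simplified sketch of name-based Merge/Switch matching. It is not TVM's actual implementation. The node names, the `(name, op)` tuple representation, and the rule "a Merge pairs with the Switch nodes in the same name scope" are all assumptions for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical, simplified node record: (node name, op type).
Node = Tuple[str, str]


def name_scope(node_name: str) -> str:
    """Return the enclosing name scope, e.g. 'cond/cond/Merge' -> 'cond/cond'."""
    return node_name.rsplit("/", 1)[0] if "/" in node_name else ""


def match_merge_to_switch(nodes: List[Node]) -> Dict[str, List[str]]:
    """Pair every Merge node with the Switch nodes that share its name scope.

    A Merge whose scope contains no Switch maps to an empty list, which is
    the case a purely name-based matcher cannot recover from.
    """
    switches_by_scope: Dict[str, List[str]] = {}
    for name, op in nodes:
        if op == "Switch":
            switches_by_scope.setdefault(name_scope(name), []).append(name)
    return {
        name: switches_by_scope.get(name_scope(name), [])
        for name, op in nodes
        if op == "Merge"
    }
```

For example, on a node list like `[("cond/Switch", "Switch"), ("cond/Merge", "Merge")]` the Merge finds its Switch, but once the Switch entries are filtered out (as the conversion appears to do), every Merge maps to an empty list.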
I observed the failure with a simple test model containing nested branches. Here is a Python script to reproduce it:
import tensorflow as tf  # TF 1.x API (tf.Session, tf.graph_util)


def gen_nested_cond():
    a = tf.constant(2., dtype=tf.float32)
    b = tf.constant(3., dtype=tf.float32)

    # Outer true branch: contains a nested tf.cond.
    def fn1(a, b):
        def nest_fn1(a, b):
            return tf.add(a, b)

        def nest_fn2(a, b):
            return tf.multiply(a, b)

        res = tf.cond(tf.less(a, b), lambda: nest_fn1(a, b), lambda: nest_fn2(a, b))
        return res

    # Outer false branch: a plain add.
    def fn2(a, b):
        return tf.add(a, b)

    r = tf.cond(tf.less(a, b), lambda: fn1(a, b), lambda: fn2(a, b))

    with tf.Session() as sess:
        print("start save model:")
        graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
        # r is produced by the outer Merge node; using r.op.name avoids
        # hard-coding the generated node name.
        graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, [r.op.name])  # This line breaks TF FE
        with tf.io.gfile.GFile("./nested_cond_new_const.pb", "wb") as f:
            f.write(graph_def.SerializeToString())


gen_nested_cond()
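One way to check whether the conversion really drops the Switch nodes is to count Switch and Merge ops in the graph before and after the call. A minimal sketch follows; `count_control_flow_ops` and `load_graph_def` are hypothetical helper names, and the loader assumes a TF version that provides `tf.compat.v1.GraphDef` and `tf.io.gfile.GFile`.

```python
from collections import Counter
from typing import Dict, Iterable


def count_control_flow_ops(op_types: Iterable[str]) -> Dict[str, int]:
    """Count the Switch and Merge entries in an iterable of op-type strings."""
    counts = Counter(op_types)
    return {op: counts[op] for op in ("Switch", "Merge")}


def load_graph_def(path: str):
    """Load a serialized GraphDef from disk (TensorFlow is imported lazily)."""
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def
```

Usage: compare `count_control_flow_ops(n.op for n in tf.get_default_graph().as_graph_def().node)` with `count_control_flow_ops(n.op for n in load_graph_def("./nested_cond_new_const.pb").node)`. A drop in the Switch count with Merge nodes still present would confirm the mismatch the TF FE runs into.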
Is this expected behavior, or am I using or understanding the function wrong?