Hi All,
I've written a simple script that traverses a Relay program and prints each node. After that, I want to construct a new Relay module from the traversed nodes.
Below is the code I have; currently, it throws an error. I'm not sure this is the best way to do it, but basically I'd like to traverse the nodes of a Relay program and build another, equivalent Relay program from them. If you know a better way to achieve this, please let me know.
```python
from tvm import relay

# an empty Relay module
mod_new = relay.Module()

def print_and_create_new_module(node):
    """ This function prints node and attempts to create a new relay Module"""
    print("\n")
    # Create a module using the node
    temp_mod = relay.Module.from_expr(node)
    # Update the mod_new with the temp_mod
    mod_new.update(temp_mod)
    print("Current node: {}".format(node))
    print("\n")

def traverse():
    # Create a simple Relay program.
    c = relay.const(2.0, "float32")
    x = relay.var("x", "float32")
    y = relay.add(c, c)
    func = relay.Function([x], y)
    mod = relay.Module({"main": func})
    # traverse the relay IR and create a new RELAY module with every traversed node.
    relay.analysis.post_order_visit(mod['main'], lambda node: print_and_create_new_module(node))

traverse()
```
Error:
```
tvm._ffi.base.TVMError: Traceback (most recent call last):
  File "C:\anaconda\envs\python3.8.1_test\lib\site-packages\tvm-0.7.dev0-py3.8-win-amd64.egg\tvm\_ffi\_ctypes\function.py", line 72, in cfun
    rv = local_pyfunc(*pyargs)
  File "C:/repos1/tvm_test/relay_traverse_graph.py", line 37, in <module>
    relay.analysis.post_order_visit(mod['main'], lambda node: print_and_create_new_module(node))
  File "C:/repos1/tvm_test/relay_traverse_graph.py", line 22, in print_and_create_new_module
    mod_new.update(temp_mod)
  File "C:\anaconda\envs\python3.8.1_test\lib\site-packages\tvm-0.7.dev0-py3.8-win-amd64.egg\tvm\relay\module.py", line 131, in update
    return _module.Module_Update(self, other)
  File "C:\anaconda\envs\python3.8.1_test\lib\site-packages\tvm-0.7.dev0-py3.8-win-amd64.egg\tvm\_ffi\_ctypes\function.py", line 207, in __call__
    raise get_last_ffi_error()
  File "C:\repos1\tvm\src\ir\module.cc", line 182
TVMError: Check failed: (*it).second == var (GlobalVar(main) vs. GlobalVar(main)) :
```