[relay] to_a_normal_form fails on a prelude function


#1

The following script reproduces the problem:

from tvm import relay
from tvm.relay.ir_pass import to_a_normal_form
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)
it = mod['iterate']
f = to_a_normal_form(it, mod)

It produces the following error message:

In `iterate`:
v0.0.1
fn <%a>(%f: fn (%a) -> %a, %x: int32) -> fn (%a) -> %a {
  let %x1: meta[relay.IncompleteType][0] = 0
  let %x2: meta[relay.IncompleteType][1] = equal(%x, %x1)
  let %x14: meta[relay.IncompleteType][11] = if (%x2) {
    let %x4: meta[relay.IncompleteType][2] = fn <%a1>(%x3: %a1) -> %a1 {
      %x3
    }
    %x4
  } else {
    let %x5: meta[relay.IncompleteType][3] = 1
    let %x6: meta[relay.IncompleteType][4] = subtract(%x, %x5)
    let %x7: meta[relay.IncompleteType][5] = @iterate(%f, %x6)
    let %x12: meta[relay.IncompleteType][9] = fn <%c, %b, %a2>(%f1: fn (%b) -> %c, %g: fn (%a2) -> %b) -> fn (%a2) -> %c {
      let %x11: meta[relay.IncompleteType][8] = fn (%x8: %a2) -> %c {
        let %x9: meta[relay.IncompleteType][6] = %g(%x8)
        let %x10: meta[relay.IncompleteType][7] = %f1(%x9)
        %x10
      }
      %x11
    }
    let %x13: meta[relay.IncompleteType][10] = %x12(%f, %x7) Incorrect number of type args in (nullptr): Expected 0but got 3;
    %x13
  }
  %x14
}

It looks like type inference fails in this case. @jroesch, @MarisaKirisame, @slyubomirsky, @joshpoll, is this a bug?


#2

It’s the same error as mine: Error "Incorrect number of type args" in Relay type inferencer. It fails at type checking when we add the function to a module.


#3
diff --git i/python/tvm/relay/prelude.py w/python/tvm/relay/prelude.py
index 17df61750..af0497e38 100644
--- i/python/tvm/relay/prelude.py
+++ w/python/tvm/relay/prelude.py
@@ -482,8 +482,8 @@ class Prelude:
         with open(prelude_file) as prelude:
             prelude = fromtext(prelude.read())
             self.mod.update(prelude)
-            self.id = self.mod["id"]
-            self.compose = self.mod["compose"]
+            self.id = self.mod.get_global_var("id")
+            self.compose = self.mod.get_global_var("compose")

@zhiics Maybe you can try this patch; it fixes the issue for me.
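
For context, here is how I read the change, as a toy model in plain Python rather than TVM code. In the trace above, the whole body of compose (the fn <%c, %b, %a2> bound to %x12) is copied into iterate, which is what you get when the prelude stores self.mod["compose"]. With get_global_var the prelude keeps only a reference to the global. ToyModule, FuncBody, GlobalRef and callee_in_call are hypothetical names made up for this sketch. Only the two lookup styles mirror the patch, and the claim that the inlined polymorphic body is what trips the type inferencer is my assumption.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class GlobalRef:
    """Hypothetical stand-in for a reference to a module-level function."""
    name: str


@dataclass(frozen=True)
class FuncBody:
    """Hypothetical stand-in for a function definition with type parameters."""
    type_params: tuple


@dataclass
class ToyModule:
    """Toy module: maps global names to function definitions."""
    functions: dict = field(default_factory=dict)

    def __getitem__(self, name):
        # Indexing hands back the definition itself (the old prelude code).
        return self.functions[name]

    def get_global_var(self, name):
        # Hands back only a reference; the definition stays in the module
        # (the patched prelude code).
        if name not in self.functions:
            raise KeyError(name)
        return GlobalRef(name)


def callee_in_call(callee):
    """Describe what ends up in the callee position of a call expression."""
    if isinstance(callee, GlobalRef):
        return "@" + callee.name
    return "fn <" + ", ".join(callee.type_params) + ">(...) {...}"


toy = ToyModule({"compose": FuncBody(("%c", "%b", "%a"))})

# Before the patch: the polymorphic body is inlined at the call site.
inlined = callee_in_call(toy["compose"])
# After the patch: the call site only names the global.
referenced = callee_in_call(toy.get_global_var("compose"))

print(inlined)
print(referenced)
```

In the toy model the first call site carries three type parameters along with the function body, while the second is just a name that gets resolved through the module.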


#4

@wweic Thanks. It solves my problem as well.