Error "Incorrect number of type args" in Relay type inferencer [Resolved]


#1

The Relay program is simple: an `add_one` function that adds 1 to its input parameter. `add_two` is generated by `compose(add_one, add_one)`.

def test_compose():
    """Repro: apply the Prelude's polymorphic `compose` to a monomorphic function.

    Builds `add_one` (float32 -> float32, adds 1.0) and
    `add_two = compose(add_one_func, add_one_func)`, registers both in the
    module, then evaluates an entry function through `veval`. In the reported
    run, type inference of `add_two` fails with "Incorrect number of type
    args ... Expected 0 but got 3" at the application of `compose`.
    """
    mod = relay.Module()
    p = Prelude(mod)

    # The Prelude's `compose`. In the printed program its definition shows up
    # as `%x5` with type params <%c, %b, %a>.
    compose = p.compose

    # remove all functions to not have pattern match to pass vm compilation
    # TODO(wweic): remove the hack and implement pattern match
    # Every module global other than `compose` is overwritten with the
    # constant 0.
    for v, _ in mod.functions.items():
        if v.name_hint == 'compose':
            continue
        mod[v] = relay.const(0)

    # add_one body: `let x1 = x; x1 + 1.0`, all float32.
    sb = relay.ScopeBuilder()
    x = relay.var('x', 'float32')
    x1 = sb.let('x1', x)
    xplusone = x1 + relay.const(1.0, 'float32')
    sb.ret(xplusone)
    body = sb.get()
    add_one = relay.GlobalVar("add_one")
    add_one_func = relay.Function([x], body)    

    # add_two is the application compose(add_one_func, add_one_func). The
    # reported error is attached to this call (`%x5(%x, %x)` in the dump).
    add_two = relay.GlobalVar("add_two")    
    add_two_func = compose(add_one_func, add_one_func)

    mod[add_one] = add_one_func
    mod[add_two] = add_two_func

    # The entry function only calls add_one; add_two is never invoked here,
    # yet the error report is attributed to `add_two`.
    f = relay.Function([], add_one(relay.const(1.0)))
    mod[mod.entry_func] = f

    # NOTE(review): `veval` is defined elsewhere in the test file; judging by
    # the traceback it compiles and runs `mod` on the Relay VM.
    result = veval(mod)()
    print(result)

The type inferencer fails with the following error:

======================================================================
ERROR: test_vm.test_compose
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/wweic/workspace/tvm/tvm_env_3.6.3/lib/python3.6/site-packages/nose/case.py", line 198, in runTest
    self.test(*self.arg)
  File "/Users/wweic/workspace/tvm/tests/python/relay/test_vm.py", line 252, in test_compose
    result = veval(mod)()
  File "/Users/wweic/workspace/tvm/python/tvm/relay/backend/vm.py", line 142, in _vm_wrapper
    return _eval_vm(self.mod, self.ctx, *args)
  File "/Users/wweic/workspace/tvm/python/tvm/relay/backend/vm.py", line 105, in _eval_vm
    result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
  File "/Users/wweic/workspace/tvm/python/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000114fefe9b tvm::relay::TypeInferencer::GetType(tvm::relay::Expr const&) + 91
  [bt] (7) 8   libtvm.dylib                        0x0000000114ff8428 tvm::relay::ExprFunctor<tvm::relay::Type (tvm::relay::Expr const&)>::VisitExpr(tvm::relay::Expr const&) + 168
  [bt] (6) 7   libtvm.dylib                        0x0000000114fff142 tvm::IRFunctor<tvm::relay::Type (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Type (tvm::relay::Expr const&)>*)>::operator()(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::relay::Type (tvm::relay::Expr const&)>*) const + 338
  [bt] (5) 6   libtvm.dylib                        0x0000000114920460 std::__1::__function::__func<tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>::InitVTable()::'lambda4'(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>*), std::__1::allocator<tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>::InitVTable()::'lambda4'(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>*)>, tvm::Array<tvm::Tensor, void> (tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>*)>::operator()(tvm::NodeRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::relay::Expr const&)>*&&) + 32
  [bt] (4) 5   libtvm.dylib                        0x0000000114ff9bdd tvm::relay::TypeInferencer::VisitExpr_(tvm::relay::CallNode const*) + 1069
  [bt] (3) 4   libtvm.dylib                        0x0000000115005110 tvm::relay::TypeInferencer::GeneralCall(tvm::relay::CallNode const*, tvm::Array<tvm::relay::Type, void>) + 2656
  [bt] (2) 3   libtvm.dylib                        0x0000000115002673 tvm::relay::TypeInferencer::ReportFatalError(tvm::NodeRef const&, tvm::relay::Error const&) + 179
  [bt] (1) 2   libtvm.dylib                        0x0000000114ccdcc7 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5399
  [bt] (0) 1   libtvm.dylib                        0x00000001148851a9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/wweic/workspace/tvm/src/relay/ir/error.cc", line 132
TVMError:
Error(s) have occurred. We have annotated the program with them:

In `add_two`:
v0.0.1
fn () -> fn (float32) -> float32 {
  let %x: meta[relay.IncompleteType][0] = fn (%x1: float32) -> float32 {
    let %x2: meta[relay.IncompleteType][1] = let %x11: float32 = %x1
    let %x3: meta[relay.IncompleteType][2] = fn (%p0: float32, __dict__=meta[StrMap][0]) -> float32 {
      add(%p0, 1f)
    }
    let %x4: meta[relay.IncompleteType][3] = %x3(%x11)
    %x4
    %x2
  }
  let %x5: meta[relay.IncompleteType][4] = fn <%c, %b, %a>(%f: fn (%b) -> %c, %g: fn (%a) -> %b) -> fn (%a) -> %c {
    let %x6: meta[relay.IncompleteType][5] = fn (%x7: %a) -> %c {
      let %x8: meta[relay.IncompleteType][6] = %g(%x7)
      let %x9: meta[relay.IncompleteType][7] = %f(%x8)
      %x9
    }
    %x6
  }
  let %x10: meta[relay.IncompleteType][8] = %x5(%x, %x)Incorrect number of type args in (nullptr): Expected 0 but got 3;
  %x10
}

cc @jroesch @MarisaKirisame @slyubomirsky


[relay] to_a_normal_form fails on a prelude function