[Relay][split] Don't know how to handle type 'tvm.relay.expr.TupleWrapper'


#1

I want to test the relay.split op, but I can't get it to produce the correct result.
The following code is my relay.split test case:

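import numpy as np
import tvm
from tvm import relay

# run_infer_type is assumed to be a helper defined elsewhere in this test file.
def test_spilt():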
    dshape = (6, 3)
    dtype = 'float32'
    x_data = np.random.rand(*dshape).astype(dtype)
    intrp = relay.create_executor(ctx=tvm.context('llvm', 0), target='llvm')
    print("xdata:\n", x_data)

    x_var = relay.var("x", relay.TensorType(x_data.shape, "float32"))
    fwd_fun = relay.Function([x_var], relay.split(x_var, [2, 5], 0))
    fwd_fun = run_infer_type(fwd_fun)
    print("=================================")
    print("fwd func\n", fwd_fun)
    print("=================================")
    print("fwd res:\n", intrp.evaluate(fwd_fun)(x_data))
    print("=================================")

This is the error message:

Traceback (most recent call last):
  File "/home/rui.huang/tvm/tests/python/relay/train/test_concat.py", line 216, in <module>
    test_spilt()
  File "/home/rui.huang/tvm/tests/python/relay/train/test_concat.py", line 100, in test_spilt
    fwd_fun = relay.Function([x_var], relay.split(x_var, [2, 5], 0))
  File "/home/rui.huang/tvm/python/tvm/relay/expr.py", line 296, in __init__
    _make.Function, params, body, ret_type, type_params, attrs)
  File "/home/rui.huang/tvm/python/tvm/_ffi/_ctypes/node.py", line 97, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/home/rui.huang/tvm/python/tvm/_ffi/_ctypes/function.py", line 219, in __init_handle_by_constructor__
    values, tcodes, num_args = _make_tvm_args(args, temp_args)
  File "/home/rui.huang/tvm/python/tvm/_ffi/_ctypes/function.py", line 170, in _make_tvm_args
    raise TypeError("Don't know how to handle type %s" % type(arg))
TypeError: Don't know how to handle type <class 'tvm.relay.expr.TupleWrapper'>

I'm using the latest version of the TVM code.


#2

Find out which expression in your code is the TupleWrapper (I guess it is split), and call .astuple() on it.


#3

I just commented out the TupleWrapper part and got the correct result, but I don't know whether my change is correct.

Here is the code in python/tvm/relay/op/transform.py (line 537):

def split(data, indices_or_sections, axis=0):
    """Split input tensor along axis by sections or indices.

    If indices_or_sections is an integer, the input will be divided equally
    along given axis. If such a split is not possible, an error is raised.

    If indices_or_sections is a tuple of sorted integers,
    the entries indicate where along axis the array is split.

    Parameters
    ----------
    data : relay.Expr
        The source array.

    indices_or_sections : int or tuple of int
        Indices or sections to split into. Accepts an int or a tuple

    axis : int, optional
        The axis over which to split.

    Returns
    -------
    ret : relay.Tuple([relay.Expr, relay.Expr])
        The computed result.
    """
    if isinstance(indices_or_sections, int):
        ret_size = indices_or_sections
    else:
        ret_size = len(indices_or_sections) + 1
    # return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
    return _make.split(data, indices_or_sections, axis)

#4

It is correct. TupleWrapper is nothing but syntactic sugar for TupleGetItem on the Python side. Personally, I don't think it should exist (instead, one should be able to unwrap every AST).
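To make the "syntactic sugar" point concrete, here is a toy, pure-Python model of the mechanism. Every class and function in it (`Expr`, `Call`, `TupleGetItem`, `TupleWrapper`, `make_function`, `split`) is a hypothetical, simplified stand-in, not TVM's actual implementation. The wrapper is a plain Python object rather than an AST node, so a constructor that only accepts nodes rejects it with the same kind of `TypeError` as in the traceback. Both `.astuple()` and indexing hand back real nodes.

```python
class Expr:
    """Hypothetical stand-in for an AST node that can cross the FFI."""


class Call(Expr):
    """Hypothetical operator call whose result is a tuple of outputs."""
    def __init__(self, op, args):
        self.op = op
        self.args = args


class TupleGetItem(Expr):
    """Hypothetical node selecting one field of a tuple-typed expression."""
    def __init__(self, tuple_value, index):
        self.tuple_value = tuple_value
        self.index = index


class TupleWrapper:
    """Python-only sugar: deliberately NOT an Expr subclass."""
    def __init__(self, tuple_value, size):
        self.tuple_value = tuple_value
        self.size = size

    def astuple(self):
        # Unwrap to the underlying tuple-typed expression.
        return self.tuple_value

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # Indexing is the "sugar": it just builds a TupleGetItem node.
        if index >= self.size:
            raise IndexError("Tuple index out of range")
        return TupleGetItem(self.tuple_value, index)


def make_function(params, body):
    """Hypothetical constructor mimicking the FFI check: only Expr nodes pass."""
    if not isinstance(body, Expr):
        raise TypeError("Don't know how to handle type %s" % type(body))
    return ("Function", params, body)


def split(data, indices_or_sections, axis=0):
    """Toy split that computes the output count the same way as the thread's code."""
    if isinstance(indices_or_sections, int):
        ret_size = indices_or_sections
    else:
        ret_size = len(indices_or_sections) + 1
    return TupleWrapper(Call("split", [data, indices_or_sections, axis]), ret_size)


x = Expr()
wrapped = split(x, [2, 5], 0)

# Passing the wrapper itself fails, as in the original traceback.
try:
    make_function([x], wrapped)
    error_message = None
except TypeError as err:
    error_message = str(err)

# Unwrapping first succeeds.
fn = make_function([x], wrapped.astuple())

# Indexing gives a TupleGetItem over the same underlying expression.
item = wrapped[1]
```

Under this model, commenting out the wrapper in `split` makes it return the bare node, which is why the change in #3 works. The cost is that `split(...)[i]` indexing is no longer available in that case.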


#5

OK, got it.
Thanks anyway!