[PatternLang] Match Constant Nodes

Based on the offline discussion with @matt-arm, users may need to match a pattern that contains constant nodes. For example, we may have a conv2d op with two arguments:

%0 = nn.conv2d(%x, %w)

After we bind the second argument with constants, it becomes:

%0 = nn.conv2d(%x, meta[relay.Constant][0])

Users may want to match only the second form, for example when they only want to support conv2d with constant weights.
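To make the difference between the two forms concrete, here is a toy sketch in plain Python. The `Var`/`Constant`/`Call` classes and `bind_params` are hypothetical stand-ins for Relay's IR and its parameter-binding step. They are not TVM's API and only model what binding does to the weight argument.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


# Hypothetical toy IR nodes standing in for Relay's Var, Constant and Call.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: Any


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]


def bind_params(expr, params: Dict[str, Any]):
    """Replace every Var whose name is in `params` with a Constant (toy version)."""
    if isinstance(expr, Var) and expr.name in params:
        return Constant(params[expr.name])
    if isinstance(expr, Call):
        return Call(expr.op, tuple(bind_params(a, params) for a in expr.args))
    return expr


# %0 = nn.conv2d(%x, %w)
before = Call("nn.conv2d", (Var("x"), Var("w")))
# Binding %w turns the weight into a constant, like meta[relay.Constant][0].
after = bind_params(before, {"w": ((1.0, 2.0), (3.0, 4.0))})
```

Only `after` should match a "conv2d with constant weight" pattern; `before` has the same structure but a variable weight.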

With the pattern language, we have several approaches to achieve this goal:

A.1 Using Check Functions

We can implement a check function that checks whether a specific argument in the matched subgraph is a constant node. This solution is already available upstream. An example can be found here:

The problem with this solution is that implementing the check function can be tedious when the pattern is complex.
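A toy model of the check-function approach, in plain Python. The IR classes, `match` and `match_with_check` are hypothetical simplifications written for illustration, not TVM's implementation. It shows why the approach gets tedious: the pattern cannot express "the weight is constant", so the check has to walk back down to the right argument by hand.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


# Hypothetical toy IR nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: Any


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]


# Toy pattern nodes: a wildcard and an operator-call pattern.
@dataclass(frozen=True)
class Wildcard:
    pass


@dataclass(frozen=True)
class CallPattern:
    op: str
    args: Tuple[Any, ...]


def match(pattern, expr) -> bool:
    """Structurally match a toy pattern against a toy expression."""
    if isinstance(pattern, Wildcard):
        return True
    if isinstance(pattern, CallPattern):
        return (
            isinstance(expr, Call)
            and expr.op == pattern.op
            and len(expr.args) == len(pattern.args)
            and all(match(p, a) for p, a in zip(pattern.args, expr.args))
        )
    return False


def match_with_check(pattern, expr, check: Optional[Callable[[Any], bool]] = None) -> bool:
    """Match structurally first, then run a user check on the matched expression."""
    return match(pattern, expr) and (check is None or check(expr))


# bias_add(conv2d(*, *), *): the pattern itself says nothing about constants.
pattern = CallPattern(
    "nn.bias_add", (CallPattern("nn.conv2d", (Wildcard(), Wildcard())), Wildcard())
)


def weight_is_const(expr) -> bool:
    # The check must re-navigate to the conv2d weight by position;
    # this is the part that grows with the size of the pattern.
    conv2d = expr.args[0]
    return isinstance(conv2d.args[1], Constant)


with_var = Call("nn.bias_add", (Call("nn.conv2d", (Var("x"), Var("w"))), Var("b")))
with_const = Call("nn.bias_add", (Call("nn.conv2d", (Var("x"), Constant(0.5))), Var("b")))
```

Both expressions match the structural pattern; only the check function tells them apart.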

A.2 Supporting is_const

Similar to is_input, we may enhance the pattern language to support is_const so that we can write the following pattern:

conv2d = is_op('nn.conv2d')(wildcard(), is_const())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
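A minimal sketch of how an is_const pattern node would remove the need for the check function, using a toy matcher in plain Python. `is_op`, `wildcard` and `is_const` here only mimic the builder style of the pattern above; the classes and the `match` function are hypothetical and are not TVM's implementation.

```python
from dataclasses import dataclass
from typing import Any, Tuple


# Hypothetical toy IR nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: Any


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]


# Toy pattern nodes.
@dataclass(frozen=True)
class WildcardPattern:
    pass


@dataclass(frozen=True)
class ConstPattern:
    pass


@dataclass(frozen=True)
class CallPattern:
    op: str
    args: Tuple[Any, ...]


class _OpPattern:
    """Calling is_op(name)(...) builds a CallPattern, mimicking the builder style."""

    def __init__(self, op: str):
        self.op = op

    def __call__(self, *args):
        return CallPattern(self.op, args)


def is_op(op: str) -> _OpPattern:
    return _OpPattern(op)


def wildcard() -> WildcardPattern:
    return WildcardPattern()


def is_const() -> ConstPattern:
    return ConstPattern()


def match(pattern, expr) -> bool:
    if isinstance(pattern, WildcardPattern):
        return True
    if isinstance(pattern, ConstPattern):
        # The constant requirement lives in the pattern; no check function needed.
        return isinstance(expr, Constant)
    if isinstance(pattern, CallPattern):
        return (
            isinstance(expr, Call)
            and expr.op == pattern.op
            and len(expr.args) == len(pattern.args)
            and all(match(p, a) for p, a in zip(pattern.args, expr.args))
        )
    return False


conv2d = is_op("nn.conv2d")(wildcard(), is_const())
pattern = is_op("nn.bias_add")(conv2d, wildcard())
```

With this pattern, bias_add(conv2d(%x, %w), %b) no longer matches, while the same graph with a bound constant weight does.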

A.3 Supporting All Nodes in Patterns

This would also cover A.2, but it is a bit out of scope. A more general solution could be to support all node types in patterns. For example, pattern nodes such as TuplePattern and TupleGetItemPattern explicitly check the node type. We can extend the pattern nodes to cover all node types in the TVM node system so that we solve this problem in a more general way:

conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())

One minor extension to this solution is to provide a consistent alias for every pattern node. In other words, users would not use TuplePattern / ConstantPattern directly but would use is_tuple / is_const, etc. instead.
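A toy sketch of the more general idea, in plain Python: one generic node pattern that checks a node's type and, optionally, some of its fields, with thin and consistent aliases (`is_op`, `is_const`, `is_tuple`, `wildcard`) on top. All names here, including `NodePattern` and `TupleNode`, are hypothetical and only illustrate the design; they are not TVM's classes.

```python
from dataclasses import dataclass
from typing import Any, Tuple


# Hypothetical toy IR nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: Any


@dataclass(frozen=True)
class TupleNode:
    fields: Tuple[Any, ...]


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]


@dataclass(frozen=True)
class Wildcard:
    pass


@dataclass(frozen=True)
class NodePattern:
    """Generic pattern: any node of `node_type`, optionally constraining named fields."""

    node_type: type
    fields: Tuple[Tuple[str, Any], ...] = ()


def match(pattern, expr) -> bool:
    if isinstance(pattern, Wildcard):
        return True
    if isinstance(pattern, tuple):
        # A tuple of sub-patterns matches a same-length tuple of sub-expressions.
        return (
            isinstance(expr, tuple)
            and len(pattern) == len(expr)
            and all(match(p, e) for p, e in zip(pattern, expr))
        )
    if isinstance(pattern, NodePattern):
        # One rule covers every node type: check the class, then each constrained field.
        return isinstance(expr, pattern.node_type) and all(
            match(p, getattr(expr, name)) for name, p in pattern.fields
        )
    # Anything else (e.g. an op name) is a literal compared by equality.
    return pattern == expr


# Thin, consistent aliases so users never build NodePattern directly.
def wildcard():
    return Wildcard()


def is_const():
    return NodePattern(Constant)


def is_tuple(*fields):
    return NodePattern(TupleNode, (("fields", fields),))


def is_op(op):
    def build(*args):
        return NodePattern(Call, (("op", op), ("args", args)))

    return build
```

The design choice illustrated is that supporting a new node type only requires a new alias, not a new matching rule.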

Any comments and suggestions are welcome.

cc @mbrookhart @matt-arm @zhiics @masahi @tqchen

Update: For A.2, I’ve made a POC here: https://github.com/comaniac/tvm/tree/add_const_to_pattern

With the POC, we can match the following pattern:

conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())

Hi Cody,

There are some examples in the unit tests of matching const nodes with specific values: https://github.com/apache/incubator-tvm/blob/a072da0588c542757d2815832b7f010f530b2428/tests/python/relay/test_dataflow_pattern.py#L685-L759

Adding ConstantPattern with an optional value is probably a good, quick solution for the case when you need a constant but don’t care what the value is.
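A minimal sketch of the suggested semantics, in plain Python: a hypothetical `ConstantPattern` whose optional `value` defaults to "any constant". The IR classes are toy stand-ins, not TVM's; comparing real tensor constants would need an element-wise check such as `numpy.array_equal` rather than `==`.

```python
from dataclasses import dataclass
from typing import Any, Optional


# Hypothetical toy IR nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: Any


@dataclass(frozen=True)
class ConstantPattern:
    # None means "any constant"; otherwise the constant's value must be equal.
    value: Optional[Any] = None


def match(pattern: ConstantPattern, expr) -> bool:
    if not isinstance(expr, Constant):
        return False
    return pattern.value is None or expr.value == pattern.value
```

Using `None` as the "don't care" sentinel keeps the common case short, at the cost of not being able to ask for a constant whose value is literally `None`.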

I agree that extending the pattern language makes more sense; right now it’s really focused on chained CallNodes. It’s missing a lot of the more complex functionality (Functions, Match, If, While, etc.). I’m not sure how much of that we want to match on, but I have no problem adding them to the language as needed.

:slight_smile: Want to open that POC as a PR?

Thanks for the pointer! So I’ll open a PR this afternoon including the following:

  • Add ConstantPattern
  • Add descriptions to the doc about how to match constants with or without a specific value.
  • Add a note to the doc inviting users to raise an issue or open a PR to add more pattern nodes.

Meanwhile, let me know if you have anything else to add :slight_smile:

PR filed