Based on an offline discussion with @matt-arm, users may need to match patterns that contain constant nodes. For example, consider a conv2d op with two arguments:
```
%0 = nn.conv2d(%x, %w)
```
After binding the second argument to a constant, it becomes:
```
%0 = nn.conv2d(%x, meta[relay.Constant])
```
Users may want to match only the second form, for example when they only support conv2d with constant weights.
With the pattern language, there are several approaches to achieve this:
A.1 Using Check Functions
We can implement a check function that tests whether a specific argument in the matched subgraph is a constant node. This solution is already available upstream. An example can be found here:
The problem with this solution is that the check function can become tedious to implement when the pattern is complex.
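To make the trade-off concrete, here is a toy, self-contained Python model of the check-function approach. It does not use TVM: `Var`, `Constant`, `Call`, `Wildcard`, `CallPattern`, `match`, `match_with_check`, and `weight_is_const` are hypothetical stand-ins written for illustration, not the upstream API. The structural pattern accepts any weight, so the constant requirement lives entirely in the check function, which has to walk the matched subgraph by hand.

```python
from dataclasses import dataclass
from typing import Callable


# Toy stand-ins for Relay expressions (hypothetical, not tvm.relay classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: float


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


# Toy patterns: a wildcard and an operator-call pattern.
@dataclass(frozen=True)
class Wildcard:
    pass


@dataclass(frozen=True)
class CallPattern:
    op: str
    args: tuple


def match(pattern, expr) -> bool:
    """Purely structural match on op names and argument arity, applied recursively."""
    if isinstance(pattern, Wildcard):
        return True
    if isinstance(pattern, CallPattern):
        return (
            isinstance(expr, Call)
            and expr.op == pattern.op
            and len(expr.args) == len(pattern.args)
            and all(match(p, e) for p, e in zip(pattern.args, expr.args))
        )
    return False


def match_with_check(pattern, expr, check: Callable[[Call], bool]) -> bool:
    # The check function only runs on subgraphs that already match structurally.
    return match(pattern, expr) and check(expr)


def weight_is_const(expr: Call) -> bool:
    # Hand-written traversal: bias_add -> args[0] (conv2d) -> args[1] (weight).
    return isinstance(expr.args[0].args[1], Constant)


# bias_add(conv2d(*, *), *): the pattern itself says nothing about constants.
pattern = CallPattern(
    "nn.bias_add", (CallPattern("nn.conv2d", (Wildcard(), Wildcard())), Wildcard())
)

x, b = Var("x"), Var("b")
with_var = Call("nn.bias_add", (Call("nn.conv2d", (x, Var("w"))), b))
with_const = Call("nn.bias_add", (Call("nn.conv2d", (x, Constant(1.0))), b))

print(match(pattern, with_var))                                # True
print(match_with_check(pattern, with_var, weight_is_const))    # False
print(match_with_check(pattern, with_const, weight_is_const))  # True
```

Note how `weight_is_const` hard-codes the path `args[0].args[1]`; in a larger pattern, every such constraint needs its own hand-written traversal.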
A.2 Adding is_const
Similar to is_input, we may enhance the pattern language to support is_const so that we can express the following pattern:
```python
conv2d = is_op('nn.conv2d')(wildcard(), is_const())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
```
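The following toy sketch shows what A.2 would buy us, again in plain Python without TVM. `is_op`, `wildcard`, and `is_const` mirror the names used in the proposal, but the classes behind them (`WildcardPattern`, `ConstantPattern`, `CallPattern`) are hypothetical, minimal reimplementations for illustration only. The constant requirement now sits inside the pattern, so no check function is needed.

```python
from dataclasses import dataclass


# Toy stand-ins for Relay expressions (hypothetical, not tvm.relay classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: float


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


class Pattern:
    def match(self, expr) -> bool:
        raise NotImplementedError


class WildcardPattern(Pattern):
    def match(self, expr) -> bool:
        return True


class ConstantPattern(Pattern):
    # Matches only constant nodes, such as a weight after parameter binding.
    def match(self, expr) -> bool:
        return isinstance(expr, Constant)


class CallPattern(Pattern):
    def __init__(self, op: str, args):
        self.op = op
        self.args = tuple(args)

    def match(self, expr) -> bool:
        return (
            isinstance(expr, Call)
            and expr.op == self.op
            and len(expr.args) == len(self.args)
            and all(p.match(a) for p, a in zip(self.args, expr.args))
        )


# Short aliases in the style of the proposal (hypothetical toy versions).
def wildcard() -> Pattern:
    return WildcardPattern()


def is_const() -> Pattern:
    return ConstantPattern()


def is_op(op: str):
    # Returns a builder so that is_op("nn.conv2d")(p1, p2) reads like the RFC.
    def build(*args: Pattern) -> Pattern:
        return CallPattern(op, args)
    return build


conv2d = is_op("nn.conv2d")(wildcard(), is_const())
pattern = is_op("nn.bias_add")(conv2d, wildcard())

x, b = Var("x"), Var("b")
with_var = Call("nn.bias_add", (Call("nn.conv2d", (x, Var("w"))), b))
with_const = Call("nn.bias_add", (Call("nn.conv2d", (x, Constant(1.0))), b))

print(pattern.match(with_var))    # False
print(pattern.match(with_const))  # True
```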
A.3 Supporting All Nodes in Patterns
This is related to A.2 but a bit out of scope. A more general solution would be to support all node types in patterns. For example, pattern nodes like TupleGetItemPattern explicitly check the node type. We can improve the pattern nodes to support every node type in the TVM node system, which would solve this problem in a more general way:
```python
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
```
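A minimal sketch of the A.3 idea under the same assumptions (plain Python, toy node classes, no TVM): a single generic `NodePattern` matches any node type via `isinstance` and optionally constrains its fields, so `ConstantPattern` and `TupleGetItemPattern` become thin specializations rather than hand-written matchers. `NodePattern` and its field-matching rules are hypothetical and only illustrate the design.

```python
from dataclasses import dataclass


# Toy stand-ins for TVM nodes (hypothetical, not tvm.relay classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: float


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


@dataclass(frozen=True)
class TupleGetItem:
    tuple_value: object
    index: int


class Wildcard:
    def match(self, expr) -> bool:
        return True


class NodePattern:
    """Generic pattern: matches any node type, optionally constraining its fields."""

    def __init__(self, node_type: type, **field_patterns):
        self.node_type = node_type
        self.field_patterns = field_patterns

    def match(self, expr) -> bool:
        if not isinstance(expr, self.node_type):
            return False
        for name, sub in self.field_patterns.items():
            value = getattr(expr, name)
            if isinstance(sub, (Wildcard, NodePattern)):
                ok = sub.match(value)
            elif isinstance(sub, tuple):  # a sequence of child patterns, e.g. call args
                ok = len(sub) == len(value) and all(p.match(v) for p, v in zip(sub, value))
            else:  # a plain value, e.g. an op name or a tuple index
                ok = sub == value
            if not ok:
                return False
        return True


# Specific patterns become thin specializations of NodePattern.
class ConstantPattern(NodePattern):
    def __init__(self):
        super().__init__(Constant)


class TupleGetItemPattern(NodePattern):
    def __init__(self, tuple_pattern, index: int):
        super().__init__(TupleGetItem, tuple_value=tuple_pattern, index=index)


def is_op(op: str):
    return lambda *args: NodePattern(Call, op=op, args=args)


conv2d = is_op("nn.conv2d")(Wildcard(), ConstantPattern())
pattern = is_op("nn.bias_add")(conv2d, Wildcard())

x, b = Var("x"), Var("b")
with_var = Call("nn.bias_add", (Call("nn.conv2d", (x, Var("w"))), b))
with_const = Call("nn.bias_add", (Call("nn.conv2d", (x, Constant(1.0))), b))
print(pattern.match(with_var))    # False
print(pattern.match(with_const))  # True

item = TupleGetItem(Var("t"), 0)
print(TupleGetItemPattern(Wildcard(), 0).match(item))  # True
print(TupleGetItemPattern(Wildcard(), 1).match(item))  # False
```

In this sketch, supporting another node type needs at most a small subclass that fixes `node_type` and the field names, rather than a dedicated matcher per node.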
One minor extension to this solution is to create consistent aliases for all pattern nodes. In other words, we do not expect users to use ConstantPattern directly but
Any comments and suggestions are welcome.