Hello everyone,
As far as I know, most usage of Relay's pattern language has been to create "MergeComposite" functions to offload to BYOC implementations.
- Q1: I think I remember reading somewhere that "MergeComposite" regions are ignored by the Relay->TVM lowering process. Is this generally true, or is it only true for regions annotated with an 'external compiler' tag?
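For context, here is a minimal sketch of the two-step flow I am referring to (the composite name "my_codegen.conv2d_relu" and the target "my_codegen" are hypothetical placeholders):

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

def conv2d_relu_pattern():
    # conv2d followed by relu, with free inputs
    return is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

def partition_for_my_codegen(mod):
    # 1) wrap matched regions in functions carrying a "Composite" attribute
    pattern_table = [("my_codegen.conv2d_relu", conv2d_relu_pattern())]
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    # 2) only annotation/partitioning attaches the "Compiler" attribute
    mod = relay.transform.AnnotateTarget("my_codegen")(mod)
    mod = relay.transform.PartitionGraph()(mod)
    return mod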
I want to use Relay's pattern language to do essentially the same thing as the BYOC examples, but following the Relay->TVM lowering process.
I looked at the Pattern Rewriting documentation, which seems to offer the "find pattern and replace" infrastructure.
Now, one thing I realized is that the return of the callback is a "standard" relay.op. This makes sense in the example, since the Relay program is the decomposed form of batch norm and the replacement pattern is a known Relay function.
In my case, I want to replace a "standard" Relay subgraph with a call to a Relay operator that I have implemented with TE (i.e. using te.compute). This would be similar to operator fusion, but I want to fuse operators that TVM won't fuse. An illustrative example would be fusing a complete inception block, matched with a pattern like the sketch below.
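Something like this rough sketch is what I mean by matching a whole block (the branch structure here is illustrative, not the exact inception topology):

from tvm.relay.dataflow_pattern import is_op, is_tuple, wildcard

def inception_block_pattern():
    def tower():
        # one branch: conv2d followed by relu, inputs left as wildcards
        return is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
    # four parallel branches feeding a concatenate
    return is_op("concatenate")(is_tuple([tower(), tower(), tower(), tower()]))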
- Q2: Is this a misuse of the Relay-level pattern-matching infrastructure? Should I be using something else?
I seem to have a problem getting it to work, though, because the process of registering (?) a new Relay op is not clear to me after reading the Add an Operator to Relay tutorial. In particular, I thought there was a way to avoid defining the Relay operator at the C++ level.
- Q3: Is there a way to register a new Relay operator using only the Python API?
I wrote a small trial-and-error script trying to reinvent the conv2d Relay node. It doesn't work, since I return a Tensor (produced by the compute inside my_conv2d) while a Relay expression is expected.
import tvm
from tvm import relay, te, topi
from tvm.relay.dataflow_pattern import *

@relay.op.register_compute("relay.op.my_conv2d")  # Q3 ==> why doesn't this work?
def my_conv2d(data, kernel):
    ishape = topi.util.get_const_tuple(data.shape)
    kshape = topi.util.get_const_tuple(kernel.shape)
    # NHWC output shape; I know this arithmetic is not complete (no padding/strides)
    oshape = (ishape[0], (ishape[1] - kshape[2]) + 1, (ishape[2] - kshape[3]) + 1, kshape[0])
    kh = te.reduce_axis((0, kshape[-1]), name='kh')
    kw = te.reduce_axis((0, kshape[-2]), name='kw')
    ic = te.reduce_axis((0, kshape[-3]), name='ic')
    # direct convolution over the three reduction axes
    res = te.compute(
        oshape,
        lambda CONV2D_N, CONV2D_H, CONV2D_W, CONV2D_C: te.sum(
            data[CONV2D_N, CONV2D_H + kh, CONV2D_W + kw, ic] *
            kernel[CONV2D_C, ic, kh, kw],
            axis=[kh, kw, ic]),
        name="res", tag="my_conv2d")
    return res

def min_relay_prog():
    # minimal Relay program: a single NHWC conv2d
    x = relay.var('x', shape=(1, 224, 224, 3))
    w = relay.var('w', shape=(16, 3, 3, 3))
    conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
    return conv2d

class MyPattern(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a my_conv2d
    def __init__(self):
        super(MyPattern, self).__init__()  # This line is missing in the tutorial
        self.pattern = self.look_for_pattern()

    def look_for_pattern(self):
        conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
        return conv2d_p

    def callback(self, pre, post, node_map):
        # Error: Expected RelayExpr, got Tensor
        return my_conv2d(te.placeholder((1, 224, 224, 3)), te.placeholder((16, 3, 3, 3)))
        # return relay.op.my_conv2d(...)  # relay.op doesn't have a my_conv2d attribute

template = MyPattern()
prog = min_relay_prog()
print(template.look_for_pattern().match(prog))  # Pattern matches

out = rewrite(template, prog)
print(out)
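For comparison, here is a minimal callback that does satisfy the contract (it must return a Relay expression); it just re-emits nn.conv2d from the matched arguments via node_map, since I have no registered my_conv2d op to call:

class ValidPattern(DFPatternCallback):
    def __init__(self):
        super(ValidPattern, self).__init__()
        self.data = wildcard()
        self.weight = wildcard()
        self.pattern = is_op('nn.conv2d')(self.data, self.weight)

    def callback(self, pre, post, node_map):
        # node_map maps each sub-pattern to the list of matched Relay nodes
        data = node_map[self.data][0]
        weight = node_map[self.weight][0]
        # a RelayExpr, so rewrite() accepts it (unlike my te.Tensor above)
        return relay.op.nn.conv2d(data, weight, data_layout="NHWC")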
- Q5: The tutorial has, in the callback method, the expression return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps.data.asnumpy().item())[0]. Why do we need the [0], and how would we handle multi-output Relay operators?
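My current understanding, for what it's worth: batch_norm returns a TupleWrapper over three outputs (output, moving_mean, moving_var), and [0] projects out the first one; presumably a multi-output replacement would return the tuple itself or project fields explicitly:

x = relay.var("x", shape=(1, 16, 224, 224))
gamma, beta = relay.var("gamma", shape=(16,)), relay.var("beta", shape=(16,))
mean, var = relay.var("mean", shape=(16,)), relay.var("var", shape=(16,))

bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var)  # TupleWrapper of length 3
out = bn[0]  # sugar for relay.TupleGetItem(bn.astuple(), 0)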
Just in case you missed it: in __init__() we have to call the super constructor. Otherwise an error is raised about missing attributes (in this case require_datatype).
- Q6: What is this require_datatype attribute, and when would we need to set it to something specific?
Sorry for the long post.