I’m interested in compiling certain layers of a graph for CPU execution and others for GPU. From the Relay API, it seems like this is possible; however, I can’t find any examples showing the proper way to do it. Anyone know where I might find one?
Are you talking about heterogeneous compilation examples? You can find them here:
That is what I was looking for, although I have a related question you might be able to help with. A lot of useful Relay features rely on annotations that can be conveniently added from Python. However, my main interaction with Relay is through a frontend graph parser that spits out an entire Relay function. If I wanted to annotate certain operations in that function, what’s the intended method? Do I need to write a full IR pass that includes C++ (similar to quantization)?
@jwfromm Yes, annotating at the expression level in the Relay source is a little annoying if you mainly work with the program produced by the parser, since the graph is usually large. Most of the annotations are general-purpose, though. If you have special needs, one option is to write another pass. Even then, you can reuse the existing annotation operators; you only need to implement the logic that decides where to annotate the graph. Hope this helps.
@jwfromm You can just write a simple pass that annotates the program in Python.
import tvm
from tvm import relay
import tvm.relay.testing
from tvm.relay.expr_functor import ExprMutator


class ScheduleConv2d(ExprMutator):
    """Wrap every nn.conv2d call in an on_device annotation."""

    def __init__(self, device):
        self.device = device
        super().__init__()

    def visit_call(self, expr):
        # Visit the arguments first so nested calls are annotated too.
        visit = super().visit_call(expr)
        if expr.op == tvm.relay.op.get("nn.conv2d"):
            return relay.annotation.on_device(visit, self.device)
        else:
            return visit


def schedule_conv2d_on_gpu(expr):
    sched = ScheduleConv2d(tvm.gpu(0))
    return sched.visit(expr)


resnet, params = relay.testing.resnet.get_workload()
print(resnet)

# Annotate the conv2d calls.
resnet = schedule_conv2d_on_gpu(resnet)
print(resnet)

# Rewrite the annotated program.
resnet = relay.ir_pass.rewrite_annotated_ops(resnet, 0)
print(resnet)
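For anyone who wants to see the shape of this pass without a TVM install, here is a minimal pure-Python sketch of the same idea: walk the expression bottom-up and wrap matching calls in an annotation node. Everything in it (`Var`, `Call`, `OnDevice`, `annotate`, and the `placement` dict) is a hypothetical stand-in made up for illustration, not part of the Relay API. It also assumes you might want to map several op names to devices rather than just `nn.conv2d`.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Toy expression nodes (hypothetical; these are NOT Relay classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


@dataclass(frozen=True)
class OnDevice:
    body: "Expr"
    device: str


Expr = Union[Var, Call, OnDevice]


def annotate(expr: Expr, placement: Dict[str, str]) -> Expr:
    """Rebuild expr bottom-up, wrapping calls whose op is listed in placement."""
    if isinstance(expr, Call):
        # Rewrite the arguments first so nested calls are annotated too.
        new_args = tuple(annotate(arg, placement) for arg in expr.args)
        rebuilt = Call(expr.op, new_args)
        device = placement.get(expr.op)
        return OnDevice(rebuilt, device) if device is not None else rebuilt
    if isinstance(expr, OnDevice):
        # Keep existing annotations, but still rewrite what they wrap.
        return OnDevice(annotate(expr.body, placement), expr.device)
    return expr  # Variables are left untouched.


# relu(conv2d(data, weight)) with only conv2d placed on the "gpu".
graph = Call("nn.relu", (Call("nn.conv2d", (Var("data"), Var("weight"))),))
annotated = annotate(graph, {"nn.conv2d": "gpu"})
print(annotated)
```

The only part that changes between use cases is the `placement` lookup, which lines up with the earlier point that just the annotation policy needs to be implemented.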