[RFC] Printing IR and parameters to pass pipelines

This is to discuss the API for printing Relay IR during compilation. I was advised to open an RFC on the pull request here:

https://github.com/dmlc/tvm/pull/4067

The approach in the patch is to add an optional print_ir parameter to the build config. This will print the initial IR and it will print the IR after each pass if and only if the pass changed the IR - otherwise it will just say that the IR did not change. If a pass is being run due to being a prerequisite of another pass, it will also make a note of this and say which other pass is causing the pass to be run. There is an example invocation and output at the link above.

The alternative API suggestion in the review, instead of a boolean print_ir parameter, is to add a general tracing functionality such that this can be:

  with relay.PassContext(trace=relay.transform.PrintIR()):
    mypass

I note that this cannot replicate the features in the patch as-is, though the features could be replicated with an API like this that has a more general tracing functionality that provides more information to the tracing function, together with a tracing function that captures state and modifies it, e.g. to say if this is the first pass or not (though this is tricky as state should not affect a later compilation, yet there are nested sequences that are not separate compilations - I don’t know if configurations are ever reused). One might also choose to just omit those features.

More involved functionality like that would be too much to type out each time, so instead there might be something like this, where IRTracer() returns a suitable tracing object that does the right negotiation with the compiler to do the tracing correctly (while maybe still allowing just plain functions that do something simpler):

  with relay.PassContext(trace=IRTracer()):
    mypass

This might also be nice for things like stopping compilation once a condition is satisfied, like failing a test. Though printing the IR seems common enough, and cross-cutting enough, that it could be a flag without having to get into this, but let me know. (If this gets very involved, it may go beyond what I’d do for my first patch to TVM. If this is the right approach, though, that can be determined now and someone (who might be me) can handle it later.)

Some further background on this general parameter topic was offered in the review, here:

[DISCUSS] All-in-one Build API and Pass API Composability

Cross-posting the comment:

Imagine a world where we want to add additional options, such as saving IRs to files or collecting the IRs so we can introspect them. Instead of adding an option to PassContext for each of these features, it would be great if we could design a trace callback option to which we can just pass different variants. Whether we should pass in a PrintIR pass or other callbacks can be discussed.

I understand that in order to implement what @broune proposed (comparing the IRs in stages), we might either have to make the tracer stateful or have a more involved signature (pass in the IR both before and after). I think the callback signature can be up for discussion and we can list possible proposals.

One consideration would be whether we could reuse the same signature as transform::Pass (so users do not have to learn another API).

We could discuss the alternatives.

As a concrete step to act on this, @broune, can you champion a few design proposals so we can discuss them, a bit like [DISCUSS] Embed more Bound Information into Var or Expr?

OK, I took a stab at it. Here are some properties I think are classical compiler virtues for a pass manager:

P1) Can run passes that operate on functions in parallel across functions, with different functions having progressed to different points in the pass sequence.
P2) Supports cross-function optimizations with a good ordering for processing of functions when that matters (e.g. for inlining in a traditional compiler this is important).
P3) Offers excellent support in understanding and inspecting what each pass did to the IR, how long it took, how the size of the IR changes through the pass sequence, where changes happen, etc.
P4) Makes it easy to run an IR verifier between each pass where the IR changed or at specified times.
P5) Makes it easy to bisect a failing test case to the pass that makes it fail.
P6) Makes it easy to tell what the pass sequence being executed is.
P7) Makes it easy to experiment with different pass sequences, automatically and manually.
P8) Easy to understand, modify and use and with clear documentation.
P9) Has some support for understanding analysis versus mutating passes and updating versus recalculating analyses as the IR changes.
P10) Is efficient, i.e. allows fast compilation with minimal memory overhead.
P11) Probably more things I’m not thinking of right now.

The change I’m trying to make helps with P3. It would be good to design the API with a variety of these and any other concerns in mind, but it’s a bit beyond what I’m looking to do for this. :slight_smile: So I’ll focus much more narrowly on printing/dumping IR, even though in a wider perspective probably all of this should be considered. Relevant features here include:

F1) Printing the IR, maybe only if it changes, maybe in full or as a diff.
F2) Showing a summary of compilation and what passes did (e.g. size of diff), without showing the IR itself.
F3) Dumping all manner of information, including IR, about each pass to a directory, one file per pass in files that are named with an incrementing pass number followed by the pass name. This allows focusing in on a relevant pass immediately and Unix tools like grep, wc and diff can be used. This has been useful in other compilers.

Both F1 and F2 require being able to do something both before and after a pass, be that knowing the IR before and after to make a diff, do other statistics, see whether the IR changed or measure how long the pass took to run. They also need to know what the passes that are run are, at least their names. F3 needs to know at each pass how many prior passes have run in order to number the files correctly, which is different from just knowing the IR before and after. It is also nice if there is support for saying why a pass is being run, since prerequisites involve running passes that were not in the original sequence of passes that were requested to be run.
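To make F3 concrete, here is a minimal sketch of the file-numbering idea. It is not TVM code: `DumpToDirectory` and its `after()` hook are hypothetical names, and the IR is represented as a plain string. The point is only that the dumper needs a running pass count in addition to the IR and the pass name.

```python
import os
import tempfile


class DumpToDirectory:
    """Hypothetical dumper: writes the IR text after each pass to a numbered file."""

    def __init__(self, directory):
        self.directory = directory
        self.count = 0  # number of passes dumped so far

    def after(self, ir_text, pass_name):
        # A zero-padded counter keeps the files sorted in pass order.
        filename = "%03d_%s.txt" % (self.count, pass_name)
        path = os.path.join(self.directory, filename)
        with open(path, "w") as f:
            f.write(ir_text)
        self.count += 1
        return path


# Usage with stand-in IR strings and pass names.
dump_dir = tempfile.mkdtemp()
dumper = DumpToDirectory(dump_dir)
dumper.after("x + 0", "SimplifyInference")
dumper.after("x", "FoldConstant")
dumped_files = sorted(os.listdir(dump_dir))
```

With files named like this, tools such as diff can be pointed at two consecutive files to see what a single pass did.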

Option O1: Handle dumping/printing with specific code in the pass manager
All of the information needed for printing/dumping as laid out above is readily available to the pass manager code, and currently only the pass manager (Sequential) has this information. It is fairly straightforward to implement and test, as seen in the PR I sent. I think this is reasonable for right now.

Option O2: Add a pass manager API that just runs between passes and is passed the IR
Most of the above features can be implemented in this way by maintaining state in the functions (call them before() and after()) that run before/after a pass and it can be done with the Pass interface. To print the IR unconditionally before the first pass, the before() function can capture a boolean to say if this is the first time it is being run. It is not possible to determine exactly why a pass is being run, but it would be possible to detect nesting of passes if the before-function of a pass is called before running the before-function of prerequisites or nested passes, and then the before() function could keep a stack to figure out when passes are nested or prerequisites, though these two cases could not be distinguished (and this nesting probably doesn’t lead to the cleanest print-out). Determining whether the IR changed, or generating a diff of the change, can be done by storing the IR in the before() function on the side in a way that the after() can access it and compare to the IR after. Timing a pass can be done in the same way with a timestamp stored from within the before-function and retrieved in the after-function.

This mostly works but it’s awkward for the before() and after() functions to figure out what’s going on. The statefulness also is not as great for running the same sequence of passes again, since then one would have to make sure to reset or use fresh new state.

This probably gets a bit easier with having an object that contains the state and has a before-function, after-function and maybe a reset-function or first()-function, though that then gets closer to the next options.
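As a rough illustration of such a state-holding object, here is a sketch with before(), after() and reset() hooks. The class name `ChangeTracer` is made up for this example and the IR is a plain string rather than a Relay module; it only shows how the first-pass flag, the saved IR and the timestamp would be carried between the two hooks.

```python
import time


class ChangeTracer:
    """Hypothetical stateful tracer for the O2 style of API."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Call between compilations so state does not leak into the next one.
        self.first = True
        self.saved_ir = None
        self.start = None
        self.log = []

    def before(self, ir_text):
        if self.first:
            # Print the initial IR unconditionally, but only once.
            self.log.append("initial IR:\n" + ir_text)
            self.first = False
        self.saved_ir = ir_text
        self.start = time.perf_counter()

    def after(self, ir_text):
        elapsed = time.perf_counter() - self.start
        if ir_text == self.saved_ir:
            self.log.append("IR did not change")
        else:
            self.log.append("IR changed:\n" + ir_text)
        return elapsed


# Usage: two passes, the second of which leaves the IR untouched.
tracer = ChangeTracer()
tracer.before("a + 0")
tracer.after("a")
tracer.before("a")
tracer.after("a")
```

Note that nothing here can report which pass ran, which is the limitation discussed next.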

Very unfortunately, this could not print the name of the pass, since you can’t tell that from just the IR. The function could accept both the IR and the pass as parameters, though that again gets closer to the next option and it would no longer fit the Pass interface. Of options O2 and O3, I think this variation, accepting both the IR and the pass, is the most reasonable for right now.

Option O3: Add a Python API that provides detailed information to a Python function
This is similar to O2, but the functions are passed an object that contains information about compilation, such as why a pass is being run, how many passes have already run, whether this is a nested or prerequisite run of a pass, what the pass stack is currently, what the current pass is, whether the IR changed (for the after-function only), how long the pass took to run, what the IR was before the pass, whether this is before or after the current pass etc.

Essentially, what this is doing is exposing all of the data available inside the pass manager to also be available outside of the pass manager. There is a cost here that expensive information can be collected, e.g. what the previous IR was, which requires storing the previous IR, which can double memory usage if the IR is large, yet none of the before-functions or after-functions might need that information. Then you could imagine that there would be a way to configure what expensive information is available. This is getting to be a big API at that point, though. Instead one might say that the user of this API has to keep track of any expensive information and only cheap information is exposed.
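A sketch of what the "only cheap information is exposed" variant might look like follows. The `PassRunInfo` record and all of its field names are assumptions made for illustration, not an existing TVM structure; anything expensive, such as the previous IR, is deliberately left out and would have to be tracked by the callback itself.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class PassRunInfo:
    """Hypothetical record handed to the callback; cheap fields only."""
    pass_name: str
    is_before: bool
    passes_run_so_far: int
    pass_stack: List[str] = field(default_factory=list)
    required_by: Optional[str] = None  # set when run as a prerequisite


def describe(info):
    """Format one line of trace output from the cheap fields."""
    when = "before" if info.is_before else "after"
    line = "[%d] %s %s" % (info.passes_run_so_far, when, info.pass_name)
    if info.required_by is not None:
        line += " (prerequisite of %s)" % info.required_by
    return line


# Usage: a prerequisite pass being run on behalf of another pass.
info = PassRunInfo(
    pass_name="InferType",
    is_before=True,
    passes_run_so_far=2,
    pass_stack=["FoldConstant", "InferType"],
    required_by="FoldConstant",
)
summary = describe(info)
```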

(Also it would be nice if passes would say whether they changed the IR, so that it’s cheap to see if the IR changed without having to store and compare against the previous IR, and a debug mode could cross-check this.)

There are a lot of variations on these options, e.g. all 3 can be combined for different use cases. A reasonable combination of 2 and 3 might be something like having a way to insert a Pass between all other passes that then just gets the IR, or have a function that takes an object with a lot of information and can print all of that.

I would personally assume that 1 is going to be part of the picture in the end and that cross-cutting concerns like timing passes and dumping information will be a good fit for it, so it’s not necessary to go to 2 or 3 yet just to show information about the passes and the IR, but certainly this has to fit into the overall design goals of TVM and I’m not as much into the details of that at this time.

Great points. As a meta comment, it is relatively rare to have a compiler pipeline exposed in Python, and the added flexibility might make things interesting (e.g. the ability to plug in callbacks) and less conventional.

One potentially interesting question, which we repeatedly ask ourselves: will the ML compiler landscape become similar to neural networks themselves, so that we will need easy, pragmatic APIs (like pytorch.Modules or Keras) to construct the pipelines and try things out? E.g. would my customized quantization pipeline become something like the resnet models? We think such a direction could happen, and that means we want to learn lessons not only from typical compiler pass managers (which are more or less like the prototxt config files of first-generation deep learning frameworks), but more importantly from deep learning framework APIs.

The before/after approach seems to be in the right direction, as it makes adding new features more composable. Timing, IRVerifier, and dumping can all be implemented in the same way. The only question is whether we should keep the same API interface as Pass (Module->Module), which seems to be hard atm.
The additional need for the pass name, and perhaps before and after info, calls for a slightly more complicated function signature.

Here is one example stab at such a function (which can be passed as a PackedFunc).

# signature of pass callback
# pass state can be before, after
def pass_callback(module, pass_name, pass_state):
    pass

@broune Thanks for the great points. I kept in mind some of the properties you mentioned, e.g. preserving some pass state, IRVerifier, phase ordering, etc. The before and after thing is interesting, but I haven’t really thought about it. One other possibility is we could bake the state-related info into the pass context, but this could complicate the build interface.

I think there are two distinct questions here: (Q1) How would a fully general API look that is great for customizing functionality for users, and (Q2) What’s the best way to do printing at this time as a specific feature? We’re discussing how to solve Q2 by implementing Q1 and then implementing Q2 in terms of that, though even if Q1 were already available, it would not necessarily be the best way to implement Q2. If the general decision is to have no flag-like functionality at all, though, then Q1 is the only option.

I’d also mention that I think you’re interested in more features of this API than just those relevant for printing/accessing information. I think you also want to be able to disable/enable, insert and modify passes, and maybe even the module, in the callback, so it gets more involved than what we’ve talked about so far. Unless I’m getting that wrong?

For Q2-via-Q1, I’d pass in the pass itself rather than the pass name. The name can be easily extracted from the pass, but perhaps the function is interested in more properties of the pass than the name. To replicate the current printing features in the PR, it would also be necessary to have more fields available on the pass state, beyond before/after, such as the reason the current pass is being run in case it is a pre-requisite. That would end up with a signature as (module, pass, info).

There is also a question of when the callback runs. E.g. as far as I can tell from the constant folding code, it calls FuseOps multiple times on sub-pieces of the code, and it’s not clear to me if the callbacks should be called for every such invocation. Meanwhile, if there is a nested Sequential, then probably they should be called within that. So they could be called, but the situation could be communicated to the callback via the info parameter, which then lets it decide what it wants to do. This design is leading down a path where every callback has to do substantial work just to decide when it wants to do something, though maybe this could be improved by having a function like IgnoreNested(foo) or BeforeOnly(foo) that takes a callback and returns a wrapped callback where nested calls are ignored and not passed on to the wrapped callback, or where only Before-calls are passed on, respectively.
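For illustration, the wrapper idea could look something like the sketch below. It assumes the (module, pass, info) signature mentioned above and, purely for the example, treats info as a dict with hypothetical "is_before" and "nested" keys; the module and pass are stand-in strings.

```python
def before_only(callback):
    """Return a callback that forwards only the before-pass calls."""
    def wrapped(module, pass_obj, info):
        if info["is_before"]:
            callback(module, pass_obj, info)
    return wrapped


def ignore_nested(callback):
    """Return a callback that drops calls made for nested pass runs."""
    def wrapped(module, pass_obj, info):
        if not info["nested"]:
            callback(module, pass_obj, info)
    return wrapped


seen = []


def record(module, pass_obj, info):
    seen.append(pass_obj)


# Usage: the wrappers compose, so the inner callback stays trivial.
trace = ignore_nested(before_only(record))
trace("mod", "FoldConstant", {"is_before": True, "nested": False})
trace("mod", "FuseOps", {"is_before": True, "nested": True})
trace("mod", "FoldConstant", {"is_before": False, "nested": False})
```

Only the first call reaches record(); the nested call and the after-call are filtered out by the wrappers.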

OK. Let me try to summarize. The goal is indeed to solve Q2 by implementing Q1 and then building Q2 on top of it, which I believe should be the way to go.

I think we don’t have to complicate the situation by adding too much additional info, as at some point we need to stop and make an engineering trade-off. So how about we go with the following signature.

def pass_callback(module, pass_info, is_before):
    pass

In terms of implementation, we can just invoke those callbacks in the operator() of the leaf nodes: ModulePass and FunctionPass. Note that this also corresponds directly to the semantics of “tracing”, as we are simply tracing every leaf pass being invoked, and it can be used to implement the timing and logging features needed.

Note that it does not provide additional info such as prerequisites, etc. But like you said, moving the API in that direction would add another level of complication that may not be necessary.
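As a sketch of how a timing feature might sit on top of this signature: the code below is not the TVM implementation. `PassInfo` is a stand-in that assumes only a `name` field, `run_leaf_pass` is a toy substitute for the leaf-node operator(), and the module is a plain string.

```python
import time
from collections import namedtuple

# Stand-in for the real pass_info; only `name` is assumed here.
PassInfo = namedtuple("PassInfo", ["name"])


class TimingTrace:
    """Callback with the (module, pass_info, is_before) signature that times passes."""

    def __init__(self):
        self.starts = []   # stack of start times, so nested passes pair up correctly
        self.elapsed = []  # list of (pass name, seconds)

    def __call__(self, module, pass_info, is_before):
        if is_before:
            self.starts.append(time.perf_counter())
        else:
            start = self.starts.pop()
            self.elapsed.append((pass_info.name, time.perf_counter() - start))


def run_leaf_pass(pass_func, pass_info, module, trace):
    """Toy leaf-pass invocation that calls the trace before and after the pass."""
    trace(module, pass_info, True)
    module = pass_func(module)
    trace(module, pass_info, False)
    return module


# Usage with a fake pass that rewrites a string "module".
timing = TimingTrace()
result = run_leaf_pass(
    lambda m: m.replace(" + 0", ""), PassInfo("FoldConstant"), "x + 0", timing
)
```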

OK. What fields are on pass_info? Or is that just a name for the pass itself?

Let’s call a pass “traced” if the callback will be called when that pass runs. There is a question of when passes are traced, depending on how they get run. Some possible recursive rules for what passes are traced:

  1. All passes listed in the outer Sequential are traced (duh).
  2. All passes listed in a traced Sequential are traced.
  3. All passes that are declared as a prerequisite of a traced pass are traced.
  4. All passes that are run directly by a traced pass are traced (e.g. constant folding creates a Sequential while it runs, that contains FuseOps, and runs that Sequential several times on sub-parts of the module)

I would suggest rules 1-3 but not 4, since 4 can cause a number of callbacks that is as large as the module is. That’s typically not interesting information, even if it will sometimes be. Though it seems like just putting this into ModulePass and FunctionPass would include rule 4? Maybe that’s OK. I do think it will lead to large and unwieldy printouts.

What’s the API for setting a trace? Is it something like this?

with relay.PassContext(trace=MakePrintTrace()):
    ...

PassInfo only contains name, opt_level, and required passes. We can obtain a pass from the registry based on the name. I would suggest we directly check traced in the ModulePass and FunctionPass as well so 4 can also be covered.

I think 4 is good, as it covers all the “traces” we care about, and that is usually what trace means. If we really want to suppress certain tracing, we could make a no-tracing context (by copying the current context and changing the trace argument).
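To illustrate the no-tracing context idea, here is a toy sketch. `TraceContext` is a made-up stand-in for a pass context that carries a trace callback; it is not the real PassContext, and passes are represented by their names only.

```python
class TraceContext:
    """Hypothetical stand-in for a pass context that carries a trace callback."""

    _stack = []

    def __init__(self, trace=None):
        self.trace = trace

    def __enter__(self):
        TraceContext._stack.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TraceContext._stack.pop()
        return False

    @staticmethod
    def current_trace():
        # The innermost context decides whether tracing is active.
        if TraceContext._stack:
            return TraceContext._stack[-1].trace
        return None


events = []


def record(name):
    events.append(name)


def run_leaf_pass(name):
    """Toy leaf pass: reports itself to the current trace, if any."""
    trace = TraceContext.current_trace()
    if trace is not None:
        trace(name)


with TraceContext(trace=record):
    run_leaf_pass("FoldConstant")
    with TraceContext(trace=None):  # suppress tracing for nested work
        run_leaf_pass("FuseOps")
    run_leaf_pass("DeadCodeElimination")
```

A pass that runs many small nested invocations could enter such a context itself to keep the printout manageable.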

I do not have strong opinions on the API. Perhaps we could get more ideas; it is always good to make the API intuitive and easy to type out. Here is one attempt :slight_smile:

with relay.PassContext(trace=relay.trace.PrintIR()):
    pass

OK, to summarize, the proposal is then an API like this:

with relay.BuildConfig(trace=relay.trace.PrintIR()):
    pass

with trace requiring a signature like this:

def pass_callback(module, pass_info, is_before):
    pass

And the trace function is called before and after each pass from

ModulePassNode::operator()
FunctionPassNode::operator()

but not SequentialNode::operator(). Is that right?

Maybe later we can allow trace= to also work for a list of trace functions.
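Until then, a list of trace functions can be folded into a single one by the caller. A minimal sketch, where `combine_traces` is a hypothetical helper and the module and pass_info arguments are stand-in strings:

```python
def combine_traces(traces):
    """Return one trace function that calls each of the given ones in order."""
    def combined(module, pass_info, is_before):
        for trace in traces:
            trace(module, pass_info, is_before)
    return combined


calls = []


def log_names(module, pass_info, is_before):
    calls.append(("name", pass_info, is_before))


def log_sizes(module, pass_info, is_before):
    # len() of the stand-in string plays the role of an IR size metric.
    calls.append(("size", len(module), is_before))


# Usage: both trace functions see the same call.
trace = combine_traces([log_names, log_sizes])
trace("x + 0", "FoldConstant", True)
```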

Looks good, thanks @broune

@broune can you summarize by concluding the thread and let us move to the implementation :slight_smile: ?