Missing memoization in ExprFunctor

Currently I’m about to merge a PR https://github.com/apache/incubator-tvm/pull/5310 which adds memoization logic for ExprFunctor in two different derived classes.

I’ve also noticed that there are many places in the code base with the same duplicated memoization logic for ExprFunctor, e.g.

Since ExprMutator does have a memo built in, I wonder why ExprFunctor doesn’t have one as well. With it, I think we could clean up the existing derived classes a lot. Is there a deep reason ExprFunctor cannot have built-in memoization?

Since ExprFunctor doesn’t have a memo, derived classes need to override the VisitExpr(const Expr& expr) method and implement the dispatching logic themselves, in addition to the memo logic.

I think it is a good item to add to our refactoring effort. cc @tqchen @zhiics

ExprFunctor is a more generic base class for dispatching. While it can be used to build recursive visitors, the visitor is not its only use case. For example, we use the functor to build the dispatching Rewriter callbacks in the recent non-recursive visitor.

Memoization could also have undesired consequences if the translation depends on the context (e.g. the result for the same expression might differ depending on whether you are in the then or the else branch of an if).
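To make the context problem concrete, here is a minimal sketch using a hypothetical toy node type and translator (none of these names are TVM APIs). BranchTranslator’s output depends on which branch it is currently in, and a memo keyed only on the expression hands back the first branch’s result when the same expression is visited from the second branch.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical toy node (not relay::Expr): just a variable name.
struct Node {
  std::string name;
};
using Expr = std::shared_ptr<Node>;

// Translation depends on hidden context: the branch ("then"/"else") we are in.
struct BranchTranslator {
  std::string branch;                          // current context
  bool memoize = false;
  std::unordered_map<Expr, std::string> memo;  // keyed on the Expr only

  std::string Visit(const Expr& e) {
    assert(e != nullptr);
    if (memoize) {
      auto it = memo.find(e);
      if (it != memo.end()) return it->second;  // ignores `branch` entirely
    }
    std::string res = branch + "_" + e->name;
    if (memoize) memo[e] = res;
    return res;
  }
};
```

With memoize set, visiting the same x first under "then" and then under "else" returns "then_x" both times. Nothing fails loudly; the output is just silently wrong, which is the kind of surprise a base class with built-in memoization could introduce.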

Due to these considerations, it is better to have “no surprises” and keep functors as simple as possible (thus no memoization).


Hmm, then does it make sense to have another derived class, say MemoizedExprFunctor, that does have memo logic built in, for the simple use cases in derived classes that require memoization? The PR I saw today could derive from it instead of ExprFunctor, for example.

Basically, this is for encapsulating logic that is already duplicated in our codebase. The rest of the code, which needs more care, would not be affected.

Given that ExprFunctor also takes additional arguments (besides the first one), it is unclear how we would define such memoization unless we restrict it to a single argument. The additional code for the memo does not take many locs, so I think either way is fine (it is a tradeoff between reuse and fewer layers of abstraction).
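A minimal sketch of what memoizing a functor with an extra argument would entail, assuming a hypothetical ScaledTranslator with signature int(const Expr&, int) (a toy, not relay’s ExprFunctor). The extra argument has to become part of the memo key, which in turn requires every argument type to be comparable or hashable; that generic requirement is what the single-argument restriction avoids.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <utility>

// Hypothetical toy node (not relay::Expr) holding a constant.
struct Node {
  int value;
};
using Expr = std::shared_ptr<Node>;

// Stand-in for a functor with signature R(const Expr&, int): the extra
// argument (a scale factor) changes the result, so it must be part of the key.
struct ScaledTranslator {
  std::map<std::pair<Expr, int>, int> memo;  // key: (expr, extra argument)
  int misses = 0;                            // how many results were computed

  int Visit(const Expr& e, int scale) {
    assert(e != nullptr);
    auto key = std::make_pair(e, scale);
    auto it = memo.find(key);
    if (it != memo.end()) return it->second;
    ++misses;
    int res = e->value * scale;
    memo[key] = res;
    return res;
  }
};
```

Keying on the Expr alone would make Visit(c, 5) return the cached result of Visit(c, 2), the same failure as the context-dependent case above.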

Note that in addition to the memo logic itself, we can also remove the VisitExpr(const Expr& expr) override in the derived classes (which, as far as I can tell, exists only for memoization), and hence dispatching logic like this can also be removed:

I think this dispatching logic is already handled by ExprFunctor. So if the base class has memo logic built in, derived classes don’t need a VisitExpr(const Expr& expr) override at all. The PR https://github.com/apache/incubator-tvm/pull/5310/ also has this dispatch logic in its VisitExpr(const Expr& expr) override.

Moreover, without memoization in the base class, every external codegen (most of which are not public, so we don’t know how many there are) needs to implement memoization itself. I actually hit an “exponential blow-up” when compiling ResNet with my codegen because I was missing memoization. So I think the opportunity for reusing memoization from a base class is high.
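The blow-up is easy to reproduce on any graph with sharing. The sketch below uses a hypothetical toy node type (a constant, or the sum of two sub-expressions) rather than relay::Expr. Chain(depth) builds x_{i+1} = x_i + x_i, so the graph has depth + 1 distinct nodes, while an unmemoized recursive visitor walks all 2^(depth+1) - 1 paths through it.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>

// Hypothetical toy node (not relay::Expr): a constant, or the sum of lhs and rhs.
struct Node {
  long long value = 0;
  std::shared_ptr<Node> lhs, rhs;  // both set for an addition, both null for a constant
};
using Expr = std::shared_ptr<Node>;

Expr Const(long long v) {
  auto n = std::make_shared<Node>();
  n->value = v;
  return n;
}
Expr Add(Expr a, Expr b) {
  auto n = std::make_shared<Node>();
  n->lhs = a;
  n->rhs = b;
  return n;
}

// x_{i+1} = x_i + x_i: depth + 1 distinct nodes, but 2^(depth + 1) - 1 nodes
// once the sharing is unfolded into a tree.
Expr Chain(int depth) {
  Expr e = Const(1);
  for (int i = 0; i < depth; ++i) e = Add(e, e);
  return e;
}

// Recursive evaluator; with `memoize` set, each distinct node is computed once.
struct Evaluator {
  bool memoize = false;
  std::unordered_map<Expr, long long> memo;
  long long visits = 0;  // how many times a node's value was actually computed

  long long Visit(const Expr& e) {
    assert(e != nullptr);
    if (memoize) {
      auto it = memo.find(e);
      if (it != memo.end()) return it->second;
    }
    ++visits;
    long long res = e->lhs ? Visit(e->lhs) + Visit(e->rhs) : e->value;
    if (memoize) memo[e] = res;
    return res;
  }
};
```

For Chain(20), the unmemoized evaluator computes 2,097,151 nodes while the memoized one computes 21. Presumably the residual connections in ResNet create the same kind of sharing, one level per block.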

It seems all uses of ExprFunctor in the code base that add memoization on top take a single const Expr& argument, so I think the single-argument restriction is reasonable.

So given these reasons, I think the cost-to-benefit ratio favors adding MemoizedExprFunctor. Anyway, I’ll send a PR and we can decide whether it is a good addition.

That dispatching logic can certainly be simplified to a one-liner, which reduces the memo logic addition to about 10 loc:

Result VisitExpr(const Expr& expr) final {
  auto it = memo_.find(expr);
  if (it != memo_.end()) {
    return it->second;
  }
  auto res = ParentClass::VisitExpr(expr);
  memo_[expr] = res;
  return res;
}

In this case, I think both make sense. It is OK to add such a base class and document its restrictions. To be careful in the beginning, let us only put it in an internal header of relay (inside src).

Yes, the point is that each derived class ends up having the exact same 10 loc. Until now we have had only 2 or 3 cases, so that might have been overlooked, but the PR by @zhiics and my own experience made me realize that every external codegen would duplicate these 10 loc. That wouldn’t be great.

Sounds good.

While it is always possible to introduce more reuse by adding new layers of abstraction, there is also an additional cost to introducing more abstraction (via sub-classing). So it is usually a trade-off.

In my experience, a 10 loc duplication can also be fine, as long as the pattern is clearly documented. We do need to make an effort to clean up the existing implementations that were overlooked, though, so that once they are refactored, others looking for examples won’t copy the wrong ones.

That being said, more thought needs to be put into the naming (we might need a name other than functor) and into the implementation if we want to make it a public helper. In particular, it would be great to discuss

Some Non-Recursive Alternatives

For example, here are some alternatives to implement the codegen.

  • C0: Use MixedModeVisitor and store the transformed bindings (output) in a Map; this removes the need to implement another recursive visitor.
  • C1: Create a TempExpr to store the Result, and directly use MixedModeExprMutator to translate the Expr to the related TempExpr for further consumption (quite close to the memoized version, but with non-recursive traversal added).
  • C2: Convert to A-normal form (every call is bound to a let), then directly generate a map from var to the Output by visiting the ANF. Note that A-normal form explicitly generates a binding for each Call, which removes the need for memoization.
  • C3: On top of C2, discuss additional normal forms in Relay that provide extra info needed for codegen.
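For a rough feel of C0, here is a sketch of a non-recursive post-order traversal that stores each node’s result in a map, using a hypothetical toy node type (a constant, or the sum of two operands). It only illustrates the idea of an explicit stack plus a result map; it is not how MixedModeVisitor is actually implemented.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy node (not relay::Expr): a constant, or the sum of lhs and rhs.
struct Node {
  long long value = 0;
  std::shared_ptr<Node> lhs, rhs;  // both set for an addition, both null for a constant
};
using Expr = std::shared_ptr<Node>;

Expr Const(long long v) {
  auto n = std::make_shared<Node>();
  n->value = v;
  return n;
}
Expr Add(Expr a, Expr b) {
  auto n = std::make_shared<Node>();
  n->lhs = std::move(a);
  n->rhs = std::move(b);
  return n;
}

// Post-order evaluation with an explicit stack instead of recursion. Each
// stack entry is (node, operands_already_pushed). The result map doubles as
// the visited set, so shared nodes are computed only once.
std::unordered_map<Expr, long long> EvalIterative(const Expr& root) {
  assert(root != nullptr);
  std::unordered_map<Expr, long long> result;
  std::vector<std::pair<Expr, bool>> stack;
  stack.emplace_back(root, false);
  while (!stack.empty()) {
    Expr e = stack.back().first;
    bool expanded = stack.back().second;
    stack.pop_back();
    if (result.count(e)) continue;  // already computed via another path
    if (!e->lhs) {                  // constant: no operands to wait for
      result[e] = e->value;
    } else if (!expanded) {
      stack.emplace_back(e, true);  // revisit after both operands are done
      stack.emplace_back(e->rhs, false);
      stack.emplace_back(e->lhs, false);
    } else {
      result[e] = result.at(e->lhs) + result.at(e->rhs);
    }
  }
  return result;
}
```

Note that the result map is itself a memo keyed on nodes, so C0 removes the recursion but not the bookkeeping.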

Note that C0, C2, and C3 all require maintaining some form of memo (for Vars, or for mapped results), and thus a similar 10 loc addition, which is necessary. The main benefit, though, is non-recursiveness.
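C2 in miniature: assuming a hypothetical toy A-normal form (not Relay’s) in which each call is a Binding of a named variable and operands are variable names or literals, codegen becomes a single forward pass. Shared sub-expressions are already named by the ANF, so the only bookkeeping left is the var-to-output map.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy A-normal form (not relay's): every call is bound to a
// named variable, and operands are variable names or integer literals.
struct Binding {
  std::string var;                // bound variable, e.g. "v0"
  std::string op;                 // operator name, e.g. "add"
  std::vector<std::string> args;  // operand variables or literals
};

// Single forward pass over the bindings. `outputs` maps each bound variable to
// the buffer generated for it; a shared value is just a repeated variable
// name, so it is looked up here instead of being re-translated.
std::string Codegen(const std::vector<Binding>& prog,
                    std::unordered_map<std::string, std::string>& outputs) {
  std::string code;
  for (const Binding& b : prog) {
    assert(!b.var.empty());
    std::string line = "buf_" + b.var + " = " + b.op + "(";
    for (std::size_t i = 0; i < b.args.size(); ++i) {
      if (i > 0) line += ", ";
      auto it = outputs.find(b.args[i]);
      // Variables resolve through the map; anything else is emitted as a literal.
      line += (it != outputs.end()) ? it->second : b.args[i];
    }
    line += ");\n";
    code += line;
    outputs[b.var] = "buf_" + b.var;
  }
  return code;
}
```

For the program v0 = add(1, 2); v1 = mul(v0, v0), this emits one line per binding, and the second use of v0 costs a map lookup rather than a second translation.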

Yeah, I am not a big fan of introducing this base class either, as I think the only duplicated code would really be the caching map. If you are concerned about those 10 locs, I can actually just remove them and replace them with a call to Functor<R(Expr)>::VisitExpr(expr). I did it that way because I wanted to give an error for unsupported nodes. Alternatively, I can add a checker for that.

You can override VisitExprDefault_ to add that error (for unsupported nodes) if you want a custom error message.
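A sketch of that pattern with a hypothetical ToyExprFunctor that mimics the shape of relay’s ExprFunctor, where each per-node handler falls back to VisitExprDefault_. The real class is templated on a function signature and dispatches on node types, so treat this as an illustration of the override rather than the actual API.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical toy IR with only two node kinds (not relay's node types).
enum class Kind { kConstant, kTuple };
struct Node {
  Kind kind;
  int value;
};
using Expr = std::shared_ptr<Node>;

// Toy dispatcher: every per-kind handler defaults to VisitExprDefault_.
template <typename R>
class ToyExprFunctor {
 public:
  virtual ~ToyExprFunctor() = default;
  virtual R VisitExpr(const Expr& e) {
    assert(e != nullptr);
    switch (e->kind) {
      case Kind::kConstant: return VisitConstant_(e);
      case Kind::kTuple: return VisitTuple_(e);
    }
    throw std::logic_error("unknown node kind");
  }
  virtual R VisitConstant_(const Expr& e) { return VisitExprDefault_(e); }
  virtual R VisitTuple_(const Expr& e) { return VisitExprDefault_(e); }
  virtual R VisitExprDefault_(const Expr&) {
    throw std::runtime_error("no default handler for this node");
  }
};

// A codegen that supports only constants and reports anything else clearly,
// without overriding VisitExpr or re-implementing the dispatch.
class MyCodegen : public ToyExprFunctor<std::string> {
 public:
  std::string VisitConstant_(const Expr& e) override {
    return std::to_string(e->value);
  }
  std::string VisitExprDefault_(const Expr&) override {
    throw std::runtime_error("MyCodegen: unsupported node");
  }
};
```

Because the fallback is a single virtual hook, the custom error lives in one place, and VisitExpr stays free for the base class (or a memoizing subclass) to own.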

Thanks @masahi @zhiics for the great discussion so far. It would be great to also get your thoughts on the C0, C1, C2, and C3 styles in the long run, and on whether we need non-recursive support for this part.

ahh, I didn’t notice we have this one. Thanks.

To be honest, among C0-C3 I would not want to introduce ANF to codegen. It means we either run ANF on the whole program or run the pass internally in the extern codegen to convert it. If we run it on the whole program, I think some passes that work on the DFG would not work well or would break. If we invoke it internally, we would want developers to be aware of this and perform the conversion explicitly.

Actually, I also don’t see how much a non-recursive form would help us here, as these ASTs are usually simple (at least the ones we currently use), i.e. only some nodes need to be handled. It looks like MixedModeVisitor would need to do similar things. I may be missing something here because I haven’t used MixedModeVisitor yet.

Since the new base class would be as simple as the one below, I don’t think there is much of abstraction cost. I don’t see why we should prefer duplicating the same VisitExpr(const Expr& n) over this solution.

template <typename R>
class MemoizedExprFunctor : public ::tvm::relay::ExprFunctor<R(const Expr&)> {
  using BaseFunctor = ::tvm::relay::ExprFunctor<R(const Expr&)>;

 public:
  virtual ~MemoizedExprFunctor() {}

  virtual R VisitExpr(const Expr& n) override {
    CHECK(n.defined());
    auto it = memo_.find(n);
    if (it != memo_.end()) {
      return it->second;
    }
    auto res = BaseFunctor::VisitExpr(n);
    memo_[n] = res;
    return res;
  }

 protected:
  std::unordered_map<Expr, R, ObjectHash, ObjectEqual> memo_;
};
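To check that the proposed class composes as intended, here is a self-contained sketch that transplants it onto a hypothetical ToyExprFunctor. It uses a two-kind toy IR instead of relay, and std::unordered_map keyed on std::shared_ptr (pointer identity) in place of ObjectHash/ObjectEqual. The derived ToyCodegen only overrides the per-node handlers, and the shared operand x is emitted once.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical toy IR (not relay): a constant, or an addition of two operands.
struct Node {
  int value = 0;
  std::shared_ptr<Node> lhs, rhs;  // both non-null for an addition
};
using Expr = std::shared_ptr<Node>;
Expr Const(int v) { auto n = std::make_shared<Node>(); n->value = v; return n; }
Expr Add(Expr a, Expr b) { auto n = std::make_shared<Node>(); n->lhs = a; n->rhs = b; return n; }

// Toy dispatcher playing the role of ExprFunctor<R(const Expr&)>.
template <typename R>
class ToyExprFunctor {
 public:
  virtual ~ToyExprFunctor() = default;
  virtual R VisitExpr(const Expr& e) { return e->lhs ? VisitAdd_(e) : VisitConst_(e); }
  virtual R VisitConst_(const Expr& e) = 0;
  virtual R VisitAdd_(const Expr& e) = 0;
};

// The memoizing layer proposed above, rebuilt on the toy dispatcher.
template <typename R>
class MemoizedExprFunctor : public ToyExprFunctor<R> {
  using BaseFunctor = ToyExprFunctor<R>;

 public:
  R VisitExpr(const Expr& e) override {
    assert(e != nullptr);  // plays the role of CHECK(n.defined())
    auto it = memo_.find(e);
    if (it != memo_.end()) return it->second;
    R res = BaseFunctor::VisitExpr(e);  // dispatch to the per-node handler
    memo_[e] = res;
    return res;
  }

 protected:
  std::unordered_map<Expr, R> memo_;
};

// A "codegen" that emits one temporary per distinct addition; it never
// touches memoization or dispatch.
class ToyCodegen : public MemoizedExprFunctor<std::string> {
 public:
  std::string code;
  std::string VisitConst_(const Expr& e) override { return std::to_string(e->value); }
  std::string VisitAdd_(const Expr& e) override {
    std::string a = VisitExpr(e->lhs);
    std::string b = VisitExpr(e->rhs);
    std::string tmp = "t" + std::to_string(counter_++);
    code += tmp + " = " + a + " + " + b + ";\n";
    return tmp;
  }

 private:
  int counter_ = 0;
};
```

For y = Add(x, x) with x = Add(Const(1), Const(2)), the generated code is "t0 = 1 + 2;" followed by "t1 = t0 + t0;". One design choice to settle: the earlier snippet marks VisitExpr final, which guarantees subclasses cannot bypass the memo, whereas the class above leaves it overridable so a backend can still plug in custom memoization.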

I think we can agree that, in general, code duplication is something to be avoided, no matter how many loc. Since the discussion so far has mostly been about the loc argument, let me also suggest another point of view, which to me matters far more than reducing 10 loc.

As a user of external codegen, I don’t want to care about memoization. I expect memoization to be taken care of by Relay, and to only implement the overrides that I am actually interested in. If I want custom memoization, I can always add a VisitExpr(const Expr& n) override, but this shouldn’t be the default. Removing the bare 10 loc (or whatever the line count is) from my codegen doesn’t make my life much better, but I think removing the cognitive load of having to do memoization is worth it. It benefits all future backend implementers.

That said, if we need to consider adding non-recursive support to the base class, I need to give my proposal a second thought. It would involve a lot of complexity that I was not expecting, since my intention in this discussion is to propose a simple QoL improvement over existing code that uses the recursive visitor. Since I haven’t run into a stack overflow with the existing implementation, I’m not comfortable adding extra complexity that may end up being unnecessary.

I would say the discussion of non-recursion support is orthogonal to this one and applies to other uses of ExprFunctor in general.

I have another thought on this: how about just putting this in backend/utils.h, since its current users would be the code under there? For general passes (like to_a_norm_form, to_cps, PE, etc.) it might be different, though.

It seems the general consensus so far is that we can put a class like the one @masahi suggested in an internal header. It is always good to discuss the alternatives and the tradeoffs. Such discussions help us reach better code quality overall.

When there are disagreements, I usually find it useful to separate the factual assessments (e.g. an additional base class reduces the implementation burden; the cost of duplication is 10 loc) from our preferences (e.g. whether we think the cost of duplication can be afforded or not).

More often than not, we find that people agree on the factual assessments but not necessarily on the preferences. Once we agree on more things, it is easier to reach consensus and make better design decisions overall.

As for the name MemoizedExprFunctor itself, I feel that the signature has diverged from “functor” enough that perhaps we should choose another name, so that a “Functor” always takes a function signature as its template parameter.

One possible candidate I can come up with is MemoizedExprTranslater<R>; I would love to hear other thoughts.

Yes, that’s where I’d put this class in, given the current usage of ExprFunctor.

Makes sense, and MemoizedExprTranslater<R> sounds good to me. I’d change the typename to OutputType to make the meaning of the type parameter obvious.