[TVM] How to insert a call to a LoweredFunc into another LoweredFunc?


Hi. I’m experimenting with Call nodes. In particular, I’d like to use the C++ API to build code in which one function calls another.

This area seems to be beyond the scope of the examples, so my attempt below is largely a guess, and unfortunately it doesn’t work. Could you please review the code and help me find the cause of the segfault?

The program below is meant to define a vecadd function that adds two tensors, and a double function that takes one argument and calls vecadd to double it.
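For reference, here is a plain C++ sketch (no TVM involved) of what the two functions are meant to compute. It assumes both operate on whole buffers of length n, the way lowered TVM functions take buffer arguments. The names vecadd, double_ and double_vec are illustrative only; double_ has a trailing underscore because double is a C++ keyword.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ model of the intended semantics (no TVM). Each routine works
// on whole buffers of length n.

// vecadd: X[i] = A[i] + B[i] for every i in [0, n).
void vecadd(const float* A, const float* B, float* X, std::size_t n) {
  assert(n == 0 || (A != nullptr && B != nullptr && X != nullptr));
  for (std::size_t i = 0; i < n; ++i) {
    X[i] = A[i] + B[i];
  }
}

// "double" is a C++ keyword, so this illustrative name gets an underscore.
// It computes Y = C + C with a single call to vecadd on the whole buffers,
// rather than one call per element.
void double_(const float* C, float* Y, std::size_t n) {
  vecadd(C, C, Y, n);
}

// Hypothetical convenience wrapper for std::vector inputs.
std::vector<float> double_vec(const std::vector<float>& C) {
  std::vector<float> Y(C.size());
  double_(C.data(), Y.data(), C.size());
  return Y;
}
```

In this sketch double_ hands whole buffers to vecadd. In the TVM snippet, by contrast, the Call is built inside a compute lambda, which produces one scalar expression per element index i.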

Full sources are here https://gist.github.com/grwlf/38217ceb345fb4c106e52e92fc6706d3

```cpp
#include <iostream>
#include <unordered_map>
#include <tvm/ir.h>
#include <tvm/operation.h>
#include <tvm/build_module.h>

using namespace std;
using namespace tvm;

int main() {
  BuildConfig config = build_config();

  auto n = var("n");
  Array<Expr> shape = {n};
  Tensor A = placeholder(shape, Float(32), "A");
  Tensor B = placeholder(shape, Float(32), "B");
  Tensor X = compute(shape, FCompute([=](auto i){ return A(i) + B(i); }));

  // GNU statement expression: its value is the final `lowered;`
  auto vecadd_lowered = ({
    Schedule s = create_schedule({X->op});
    std::unordered_map<Tensor, Buffer> binds;
    auto args = Array<Tensor>({A, B, X});
    auto lowered = lower(s, args, "vecadd", binds, config);
    lowered;
  });

  cerr << "VECADD_LOWERED" << endl
       << "==============" << endl
       << vecadd_lowered[0]->body << endl;

  Tensor C = placeholder(shape, Float(32), "C");
  Tensor Y = compute(shape, FCompute([=](auto i){
      return HalideIR::Internal::Call::make(
        "vecadd", // <-- What is this name ???
        HalideIR::Internal::Call::PureExtern, // <---- ???
        vecadd_lowered[0], 0);
      }));

  /* The segfault happens in this block */
  auto double_lowered = ({
    Schedule s = create_schedule({Y->op});
    std::unordered_map<Tensor, Buffer> binds;
    auto args = Array<Tensor>({C, Y});
    auto lowered = lower(s, args, "double", binds, config);
    lowered;
  });

  cerr << "DOUBLE_LOWERED" << endl
       << "==============" << endl
       << double_lowered[0]->body << endl;

  auto target = Target::create("llvm");
  auto target_host = Target::create("llvm");
  runtime::Module mod = build(double_lowered, target, target_host, config);

  /* Output LLVM assembly to stdout */
  cout << mod->GetSource("asm") << endl;
  return 0;
}
```

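The snippet builds each lowered function inside a GNU statement expression, ({ ... }), whose value is the last expression statement inside the braces. This is a GCC/Clang extension rather than standard C++. Below is a minimal sketch of that pattern next to a portable immediately invoked lambda. fake_lower is a hypothetical stand-in for tvm::lower, so the example compiles without TVM.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for tvm::lower(): returns a one-element list.
std::vector<std::string> fake_lower(const std::string& name) {
  assert(!name.empty());
  return {name};
}

// GNU statement expression (GCC/Clang extension, not ISO C++):
// the value of the last expression statement, `lowered;`, becomes the
// value of the whole ({ ... }) block. Temporaries stay scoped inside.
std::vector<std::string> with_statement_expression() {
  auto result = ({
    std::string name = "vecadd";
    auto lowered = fake_lower(name);
    lowered;
  });
  return result;
}

// Portable equivalent: an immediately invoked lambda.
std::vector<std::string> with_lambda() {
  auto result = [&] {
    std::string name = "vecadd";
    auto lowered = fake_lower(name);
    return lowered;
  }();  // the trailing () invokes the lambda at once
  return result;
}
```

Both forms keep the helper variables (s, binds, args in the original) out of the enclosing scope. The lambda form also compiles on compilers that lack the GNU extension.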