Hi. I’m experimenting with Call nodes. In particular, I’d like to write C++ code that builds one function which calls another function.
This area happens to be beyond the scope of the examples, so my attempt below is more of a guess, and unfortunately it doesn’t work. Could you please review the code and help me find and fix the cause of the segfault?
The program below is expected to define a `vecadd` function that adds two tensors, and a `double` function that takes one argument and calls `vecadd` to double it.
Full sources are here https://gist.github.com/grwlf/38217ceb345fb4c106e52e92fc6706d3
using namespace std;
using namespace tvm;
// Reproduction case: lowers an elementwise-add kernel named "vecadd", then
// tries to build a second kernel named "double" whose compute body is a Call
// node referring to "vecadd". The author reports a segfault while lowering the
// second kernel. On success the program would print the lowered bodies to
// stderr and the generated assembly to stdout.
int main()
{
BuildConfig config = build_config();
// Symbolic extent `n`, used as the single dimension of every tensor below.
auto n = var("n");
Array<Expr> shape = {n};
Tensor A = placeholder(shape, Float(32), "A");
Tensor B = placeholder(shape, Float(32), "B");
// X(i) = A(i) + B(i)
Tensor X = compute(shape, FCompute([=](auto i){ return A(i) + B(i); } )) ;
// `({ ... })` is a GNU statement expression (GCC/Clang extension, not standard
// C++); its value is the last expression statement, here `lowered`.
auto vecadd_lowered = ({
Schedule s = create_schedule({X->op});
std::unordered_map<Tensor, Buffer> binds;
// Argument order of the lowered function: inputs A and B, then output X.
auto args = Array<Tensor>({A, B, X});
auto lowered = lower(s, args, "vecadd", binds, config);
lowered;
});
// Dump the body of the first lowered function for inspection.
cerr << "VECADD_LOWERED" << endl
<< "==============" << endl
<< vecadd_lowered[0]->body << endl;
Tensor C = placeholder(shape, Float(32), "C");
// Y(i) is defined as a Call node named "vecadd" applied to (C(i), C(i)).
// NOTE(review): the call passes two per-element values, whereas "vecadd" was
// lowered above with three tensor arguments (A, B, X) — this looks like an
// argument mismatch; confirm against the Call / LoweredFunc documentation.
// NOTE(review): it is unclear from this file whether a lowered function may be
// passed as the `func` argument of Call::make with PureExtern — TODO confirm.
Tensor Y = compute(shape, FCompute([=](auto i){
return HalideIR::Internal::Call::make(
Float(32),
"vecadd", // <-- What is this name ???
Array<Expr>({C(i),C(i)}),
HalideIR::Internal::Call::PureExtern, // <---- ???
vecadd_lowered[0], 0);
} )) ;
/* The SEGFAULT is in this block (as reported by the author) */
auto double_lowered = ({
Schedule s = create_schedule({Y->op});
std::unordered_map<Tensor, Buffer> binds;
// Argument order of the lowered function: input C, then output Y.
auto args = Array<Tensor>({C, Y});
auto lowered = lower(s, args, "double", binds, config);
lowered;
});
cerr << "DOUBLE_LOWERED" << endl
<< "==============" << endl
<< double_lowered[0]->body << endl;
// Both the device target and the host target are created from "llvm".
auto target = Target::create("llvm");
auto target_host = Target::create("llvm");
// Only `double_lowered` is passed to build(); `vecadd_lowered` is not.
// NOTE(review): if the call is meant to resolve to the "vecadd" kernel at
// link/run time, it presumably needs to be part of the built module — verify.
runtime::Module mod = build(double_lowered, target, target_host, config);
/* Output LLVM assembly to stdout */
cout << mod->GetSource("asm") << endl;
return 0;
}