[TVM Compiler] "stack overflow" happens on Linux x64 when computation graph is "big"

Hi all,
I have a handwritten LSTM model based on https://gist.github.com/grwlf/df81038e660e3acd6343bf67d609045d

import numpy as np
import tvm
import topi
import time
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/mnist/", one_hot=True)

lr = 0.001
num_steps = 1
batch_size = 64
display_step = 100

num_timesteps = 28 * 6
num_input = 28 
num_hidden = 128
num_classes = 10

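# Weight/bias shapes for the four LSTM gates (g, i, f, o) plus the final dense layer.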
sizes = [
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_hidden, num_classes),
    (num_classes,)
]

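# Matching initializers: the forget-gate bias starts at one, the dense layer weights and bias are drawn from a normal distribution.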
inits = [
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.ones, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.random.normal, 'size'),
    (np.random.normal, 'size')
]

x = tvm.placeholder((batch_size, num_timesteps * num_input), 'float32')
y = tvm.placeholder((batch_size, num_classes), 'float32')
s = tvm.placeholder((batch_size, num_hidden), 'float32')
h = tvm.placeholder((batch_size, num_hidden), 'float32')

weights = [tvm.placeholder(shape, 'float32', name="weights") for shape in sizes]

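# Reshape the flat input into (batch, time, features) and split it into per-timestep (batch_size, num_input) slices.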
xs = topi.split(topi.reshape(x, (batch_size, num_timesteps, num_input)), num_timesteps, axis=1)
xs = [topi.reshape(xt, (batch_size, num_input)) for xt in xs]
new_s = s
new_h = h

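# Unroll the LSTM recurrence; every timestep appends a new set of stages to the compute graph.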
for i in range(num_timesteps):
    inp = topi.concatenate([xs[i], new_h], 1)
    g = topi.tanh(topi.matmul(inp, weights[0]) + weights[1])
    j = topi.sigmoid(topi.matmul(inp, weights[2]) + weights[3])
    f = topi.sigmoid(topi.matmul(inp, weights[4]) + weights[5])
    o = topi.sigmoid(topi.matmul(inp, weights[6]) + weights[7])

    new_s = new_s * f + g * j
    new_h = topi.tanh(new_s) * o

logits = topi.matmul(new_h, weights[8]) + weights[9]

pred = topi.nn.softmax(logits)
correct_pred = topi.equal(topi.argmax(y, 1), topi.argmax(pred, 1))
accuracy = topi.sum(correct_pred.astype('float32')) / batch_size
loss = topi.sum(-topi.sum(y * topi.nn.log_softmax(logits), axis=1)) / batch_size
head = topi.full((1,), 'float32', 1.0)

sched = tvm.create_schedule([loss.op, accuracy.op])
lowered = tvm.lower(sched, [x, y, s, h, loss, accuracy, *weights], simple_mode=True)
print (lowered)
train_model = tvm.build(sched, [x, y, s, h, loss, accuracy, *weights])

The problem happens when num_timesteps = 28 * n with n >= 6; in that case I get roughly 3000+ stages in the scheduler.

$export LD_PRELOAD=/path/to/libasan.so
$python3 lstm.py

AddressSanitizer:DEADLYSIGNAL
=================================================================
==3952==ERROR: AddressSanitizer: stack-overflow on address 0x7ffcc2f78ed0 (pc 0x7f8f9a8b03ed bp 0x7ffcc2f79720 sp 0x7ffcc2f78ec0 T0)
    #0 0x7f8f9a8b03ec in operator new(unsigned long) (/usr/lib/x86_64-linux-gnu/libasan.so.5+0xf03ec)
    #1 0x7f8f4bf695a7 in tvm::ir::MutateArray(tvm::Array<HalideIR::Expr, void>, tvm::ir::IRMutator*) (/home//tvm/build/libtvm.so+0x44a5a7)
    #2 0x7f8f4bf61f50 in tvm::ir::IRMutator::Mutate_(HalideIR::Internal::Call const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x442f50)
    #3 0x7f8f4c0b4261 in tvm::schedule::SchedulePostProc::Mutate_(HalideIR::Internal::Call const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x595261)
    #4 0x7f8f4bf5b3a6 in std::_Function_handler<HalideIR::Expr (HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::ir::{lambda(HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*)#18}>::_M_invoke(std::_Any_data const&, HalideIR::Internal::Call const*&&, HalideIR::Expr const&, tvm::ir::IRMutator*&&) (/home//tvm/build/libtvm.so+0x43c3a6)
    #5 0x7f8f4bf64d12 in std::_Function_handler<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::IRFunctor<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)>::set_dispatch<HalideIR::Internal::Call>(std::function<HalideIR::Expr (HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*&&) (/home//tvm/build/libtvm.so+0x445d12)
    #6 0x7f8f4bf18b59 in tvm::IRFunctor<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*) const (/home//tvm/build/libtvm.so+0x3f9b59)
    #7 0x7f8f4bf193dc in tvm::ir::IRMutator::Mutate(HalideIR::Expr) (/home//tvm/build/libtvm.so+0x3fa3dc)
    #8 0x7f8f4bf5f611 in tvm::ir::IRMutator::Mutate_(HalideIR::Internal::Mul const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x440611)
    #9 0x7f8f4bf5b4f6 in std::_Function_handler<HalideIR::Expr (HalideIR::Internal::Mul const*, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::ir::{lambda(HalideIR::Internal::Mul const*, HalideIR::Expr const&, tvm::ir::IRMutator*)#21}>::_M_invoke(std::_Any_data const&, HalideIR

It looks like the problem occurs because of the recursive IRMutator calls, and it can be worked around by increasing the stack size, for example
$ulimit -s unlimited
on a Linux machine.
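Another possible workaround (just a sketch, and the stack size below is only a guess) is to run the lower/build calls in a worker thread created with a larger native stack, since the recursion happens on the calling thread's stack:

import threading

def build_model():
    # same tvm.lower / tvm.build calls as in the script above
    tvm.build(sched, [x, y, s, h, loss, accuracy, *weights])

threading.stack_size(256 * 1024 * 1024)  # request a 256 MB stack for threads created after this call
worker = threading.Thread(target=build_model)
worker.start()
worker.join()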

So, can someone please tell me: is there any way to split a “big” computation graph into a set of smaller “compute nodes”, compile them separately, and then combine them by calling one from another, so that the stack overflow is avoided without such “workarounds”?
As far as I understand, I can load the compiled modules separately at runtime:

  myadd = tvm.module.load("myadd.so")
  mymul = tvm.module.load("mymul.so")
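If I understand correctly, I could then chain them at the Python level through intermediate buffers, roughly like this (the entry-point names “myadd”/“mymul” and the buffers a–e are just placeholders):

  # a, b, c, d, e are pre-allocated tvm.nd.array buffers
  fadd = myadd["myadd"]
  fmul = mymul["mymul"]
  fadd(a, b, c)   # c = a + b
  fmul(c, d, e)   # e = c * d, consuming the intermediate result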

But is there any way to call a function defined in a compiled module from a TVM compute, maybe via the packed-function machinery (“call_packed”), or maybe via “call extern”?
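For example, would something along the lines of the external tensor function pattern work, where a registered packed function is invoked from inside the IR? A rough sketch (the function name “my_module.cell_step” and its body are purely illustrative, reusing import tvm, batch_size and num_hidden from the script above):

@tvm.register_func("my_module.cell_step")
def cell_step(a, b, out):
    # placeholder body; in practice this could dispatch into another compiled module
    tvm.nd.array(a.asnumpy() + b.asnumpy()).copyto(out)

A = tvm.placeholder((batch_size, num_hidden), 'float32', name="A")
B = tvm.placeholder((batch_size, num_hidden), 'float32', name="B")
C = tvm.extern(A.shape, [A, B],
               lambda ins, outs: tvm.call_packed(
                   "my_module.cell_step", ins[0], ins[1], outs[0]),
               name="C")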
Original question: [TVM] How to insert a call to LoweredFunc into another LoweredFunc?
Thanks.

I think it is generally not a very good idea to directly unroll a deep LSTM graph. Instead, it would be much better to express the recurrence as an explicit loop (over num_timesteps).

The scan tutorial at https://docs.tvm.ai/tutorials/language/scan.html#sphx-glr-tutorials-language-scan-py might be helpful here if we are interested in a single-kernel approach.
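As a rough illustration of what scan looks like (this is the cumulative-sum example from that tutorial, not the LSTM itself), the time loop stays symbolic instead of being unrolled into thousands of stages:

import tvm

m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X])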

@tqchen
thanks for the answer.