Hi all,
I have a handwritten LSTM model based on https://gist.github.com/grwlf/df81038e660e3acd6343bf67d609045d
```python
import numpy as np
import tvm
import topi
import time
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/mnist/", one_hot=True)

lr = 0.001
num_steps = 1
batch_size = 64
display_step = 100

num_timesteps = 28 * 6
num_input = 28
num_hidden = 128
num_classes = 10

# Parameter shapes: (weight, bias) for each of the four gates g, j, f, o,
# followed by the output layer.
sizes = [
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_input + num_hidden, num_hidden),
    (num_hidden,),
    (num_hidden, num_classes),
    (num_classes,)
]

# Initializer for each parameter (the forget-gate bias starts at ones).
inits = [
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.ones, 'shape'),
    (np.zeros, 'shape'),
    (np.zeros, 'shape'),
    (np.random.normal, 'size'),
    (np.random.normal, 'size')
]

x = tvm.placeholder((batch_size, num_timesteps * num_input), 'float32')
y = tvm.placeholder((batch_size, num_classes), 'float32')
s = tvm.placeholder((batch_size, num_hidden), 'float32')
h = tvm.placeholder((batch_size, num_hidden), 'float32')
weights = [tvm.placeholder(shape, 'float32', name="weights") for shape in sizes]

# Split the input into one (batch_size, num_input) slice per timestep.
xs = topi.split(topi.reshape(x, (batch_size, num_timesteps, num_input)), num_timesteps, axis=1)
xs = [topi.reshape(xt, (batch_size, num_input)) for xt in xs]

# Fully unroll the LSTM: every iteration adds new stages to the graph.
new_s = s
new_h = h
for i in range(num_timesteps):
    inp = topi.concatenate([xs[i], new_h], 1)
    g = topi.tanh(topi.matmul(inp, weights[0]) + weights[1])
    j = topi.sigmoid(topi.matmul(inp, weights[2]) + weights[3])
    f = topi.sigmoid(topi.matmul(inp, weights[4]) + weights[5])
    o = topi.sigmoid(topi.matmul(inp, weights[6]) + weights[7])
    new_s = new_s * f + g * j
    new_h = topi.tanh(new_s) * o

logits = topi.matmul(new_h, weights[8]) + weights[9]
pred = topi.nn.softmax(logits)
correct_pred = topi.equal(topi.argmax(y, 1), topi.argmax(pred, 1))
accuracy = topi.sum(correct_pred.astype('float32')) / batch_size
loss = topi.sum(-topi.sum(y * topi.nn.log_softmax(logits), axis=1)) / batch_size
head = topi.full((1,), 'float32', 1.0)

sched = tvm.create_schedule([loss.op, accuracy.op])
lowered = tvm.lower(sched, [x, y, s, h, loss, accuracy, *weights], simple_mode=True)
print(lowered)
train_model = tvm.build(sched, [x, y, s, h, loss, accuracy, *weights])
```
The problem happens when num_timesteps = 28 * n with n >= 6. In that case the schedule has about 3000 or more stages, and the script crashes with a stack overflow.
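As a rough sanity check on that number: the loop is fully unrolled, so the stage count grows linearly with num_timesteps. The sketch below is only a back-of-the-envelope estimate, assuming each topi call (and each output of `topi.split`) creates exactly one stage; it ignores the placeholders and the final loss/accuracy stages.

```python
# Back-of-the-envelope stage count for the unrolled LSTM.
# ASSUMPTION: every topi call inside the loop creates exactly one stage.
OPS_PER_STEP = (
    1      # topi.concatenate
    + 4    # topi.matmul, one per gate
    + 4    # bias additions
    + 1    # topi.tanh for g
    + 3    # topi.sigmoid for j, f, o
    + 3    # new_s * f, g * j, and their sum
    + 2    # topi.tanh(new_s) and the product with o
)
INPUT_OPS_PER_STEP = 2  # one topi.split output and one topi.reshape per timestep


def estimated_stages(num_timesteps):
    """Approximate number of stages; linear in num_timesteps."""
    return num_timesteps * (OPS_PER_STEP + INPUT_OPS_PER_STEP)


for n in (1, 5, 6):
    print(n, estimated_stages(28 * n))
```

For n = 6 this gives 168 * 20 = 3360, in line with the 3000+ stages observed.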
```
$ export LD_PRELOAD=/path/to/libasan.so
$ python3 lstm.py
AddressSanitizer:DEADLYSIGNAL
=================================================================
==3952==ERROR: AddressSanitizer: stack-overflow on address 0x7ffcc2f78ed0 (pc 0x7f8f9a8b03ed bp 0x7ffcc2f79720 sp 0x7ffcc2f78ec0 T0)
    #0 0x7f8f9a8b03ec in operator new(unsigned long) (/usr/lib/x86_64-linux-gnu/libasan.so.5+0xf03ec)
    #1 0x7f8f4bf695a7 in tvm::ir::MutateArray(tvm::Array<HalideIR::Expr, void>, tvm::ir::IRMutator*) (/home//tvm/build/libtvm.so+0x44a5a7)
    #2 0x7f8f4bf61f50 in tvm::ir::IRMutator::Mutate_(HalideIR::Internal::Call const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x442f50)
    #3 0x7f8f4c0b4261 in tvm::schedule::SchedulePostProc::Mutate_(HalideIR::Internal::Call const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x595261)
    #4 0x7f8f4bf5b3a6 in std::_Function_handler<HalideIR::Expr (HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::ir::{lambda(HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*)#18}>::_M_invoke(std::_Any_data const&, HalideIR::Internal::Call const*&&, HalideIR::Expr const&, tvm::ir::IRMutator*&&) (/home//tvm/build/libtvm.so+0x43c3a6)
    #5 0x7f8f4bf64d12 in std::_Function_handler<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::IRFunctor<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)>::set_dispatch<HalideIR::Internal::Call>(std::function<HalideIR::Expr (HalideIR::Internal::Call const*, HalideIR::Expr const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*&&) (/home//tvm/build/libtvm.so+0x445d12)
    #6 0x7f8f4bf18b59 in tvm::IRFunctor<HalideIR::Expr (tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, HalideIR::Expr const&, tvm::ir::IRMutator*) const (/home//tvm/build/libtvm.so+0x3f9b59)
    #7 0x7f8f4bf193dc in tvm::ir::IRMutator::Mutate(HalideIR::Expr) (/home//tvm/build/libtvm.so+0x3fa3dc)
    #8 0x7f8f4bf5f611 in tvm::ir::IRMutator::Mutate_(HalideIR::Internal::Mul const*, HalideIR::Expr const&) (/home//tvm/build/libtvm.so+0x440611)
    #9 0x7f8f4bf5b4f6 in std::_Function_handler<HalideIR::Expr (HalideIR::Internal::Mul const*, HalideIR::Expr const&, tvm::ir::IRMutator*), tvm::ir::{lambda(HalideIR::Internal::Mul const*, HalideIR::Expr const&, tvm::ir::IRMutator*)#21}>::_M_invoke(std::_Any_data const&, HalideIR
```
It looks like the problem is caused by deep recursion in the IR mutators (`IRMutator::Mutate_` and `SchedulePostProc::Mutate_` in the trace above), and it can be worked around by increasing the stack size, for example with `ulimit -s unlimited` on a Linux machine.
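The same workaround can also be applied from inside the script rather than the shell. A minimal sketch, assuming the overflow happens on the thread that calls into TVM: Python's `threading.stack_size` sets the stack size for threads created afterwards, so running the lowering/build in such a thread gives the recursive C++ passes more room. `run_with_big_stack` and its 256 MiB default are illustrative choices, not part of TVM.

```python
import threading


def run_with_big_stack(fn, *args, stack_size_mb=256, **kwargs):
    """Call fn(*args, **kwargs) in a new thread with a stack of stack_size_mb MiB.

    Native code invoked from that thread (for example TVM's C++ passes) runs on
    the enlarged stack. Exceptions are re-raised in the calling thread.
    """
    result = {}

    def target():
        try:
            result["value"] = fn(*args, **kwargs)
        except BaseException as exc:  # forward any error to the caller
            result["error"] = exc

    # stack_size() applies to threads created after this call and returns the old value.
    old_size = threading.stack_size(stack_size_mb * 1024 * 1024)
    try:
        worker = threading.Thread(target=target)
        worker.start()
    finally:
        threading.stack_size(old_size)  # restore the default for later threads
    worker.join()
    if "error" in result:
        raise result["error"]
    return result["value"]
```

Hypothetical usage with the script above: `train_model = run_with_big_stack(tvm.build, sched, [x, y, s, h, loss, accuracy, *weights])`, and likewise for `tvm.lower`. Like `ulimit -s`, this presumably only moves the limit to a larger num_timesteps rather than removing it.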
So, can someone please tell me whether there is a way to split a big computation graph into a set of small compute nodes, compile them separately, and then combine them at execution time by calling one from another, so that the stack overflow is avoided without any such workarounds?
As far as I understand, I can load separately compiled modules at runtime:

```python
import tvm

myadd = tvm.module.load("myadd.so")
mymul = tvm.module.load("mymul.so")
```
But is there any way to call a function defined in a compiled module from a TVM compute, perhaps through the packed-function machinery (`call_packed`) or through `call_extern`?
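For illustration, here is roughly what such a split could look like if the timestep loop were driven from the host instead of being unrolled into a single graph. This is a NumPy sketch, not TVM code: `lstm_step` mirrors the loop body (same gate equations, weights in the order of `sizes`), and in TVM it would presumably be the one small function compiled once (e.g. with `tvm.build`) and called num_timesteps times. The names `lstm_step` and `run_lstm` are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, s, h, w):
    """One LSTM timestep with the same gate equations as the unrolled loop.

    w follows the order of `sizes`: (weight, bias) for g, j, f, o, then the
    output layer (not used in this function).
    """
    inp = np.concatenate([x_t, h], axis=1)
    g = np.tanh(inp @ w[0] + w[1])
    j = sigmoid(inp @ w[2] + w[3])
    f = sigmoid(inp @ w[4] + w[5])
    o = sigmoid(inp @ w[6] + w[7])
    new_s = s * f + g * j
    new_h = np.tanh(new_s) * o
    return new_s, new_h


def run_lstm(x, s, h, w, num_timesteps, num_input):
    """Host-side driver: call the per-step function once per timestep."""
    xs = x.reshape(x.shape[0], num_timesteps, num_input)
    for t in range(num_timesteps):
        s, h = lstm_step(xs[:, t, :], s, h, w)
    return h @ w[8] + w[9]  # logits
```

Because each call only covers one timestep, the size of the compiled function no longer depends on num_timesteps; the trade-off is one host-side call per step.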
Original question: [TVM] How to insert a call to LoweredFunc into another LoweredFunc?
Thanks.