Thanks @tqchen. I realized that our usecase require us to compile everytime, so in this case AOT will also suffer the same fate. We checked by dumping the “so” files, the slowdown just shifted to “so” creation.
However, I might have been able to pinpoint atleast a portion of the problem. I am able to regenerate the problem with just max_pool2d. The test case is
import tvm
from tvm import relay
from tvm.relay.testing import create_workload
from tvm.contrib import graph_runtime
import time
def compile_graph(simple_net, t, config=None):
target = tvm.target.create(t)
func = relay.Function(relay.ir_pass.free_vars(simple_net), simple_net)
func, params = create_workload(func)
if config is None:
config = {"opt_level" : 3}
with relay.build_config(**config):
artifacts = relay.build(func, target=target, params=params)
if t == 'llvm':
ctx = tvm.cpu()
graph, libs, params = artifacts
module = graph_runtime.create(graph, libs, ctx)
def test_pool(batch_size, in_channel, height, width, kh, kw, stride, padding,
dtype):
data_shape = (1, height, width, in_channel)
data = relay.var("data", shape=data_shape, dtype=dtype)
simple_net = relay.nn.max_pool2d(data=data,
pool_size=(kh, kw),
strides=(stride, stride),
padding=padding,
layout='NHWC')
compile_graph(simple_net, 'llvm')
start = time.time()
test_pool(batch_size=1, in_channel=288, height=35, width=35, kh=3, kw=3,
stride=2, padding=(0, 0, 0, 0), dtype="float32")
end = time.time()
print("Float32 took", str(end - start), "seconds")
start = time.time()
test_pool(batch_size=1, in_channel=288, height=35, width=35, kh=3, kw=3,
stride=2, padding=(0, 0, 0, 0), dtype="float16")
end = time.time()
print("Float16 took", str(end - start), "seconds")
The above test case on my machine takes
Float32 --> 2 seconds
Float16 --> more than 10 min
If I use the original compute as the default schedule for pool2d, then atleast for my test case, the compilation time reduces to mere milliseconds.
Vectorization might be the main reason here.