Relay.build hangs for 5D max_pool3d + reshape + matmul

The following combination of operators (max_pool3d + reshape + matmul) with input shape (1, 8, 2, 2, 512) causes relay.build to hang for the llvm target:

import tensorflow as tf
import numpy as np
from tvm import relay

dtype = 'float32'
input_name = "input"
dshape = (1, 8, 2, 2, 512)    # NDHWC input

mm_shape = (16384, 32)        # 8 * 2 * 2 * 512 = 16384 columns after the reshape
mm_weights = np.random.random_sample(mm_shape).astype(dtype)

with tf.Session() as sess:
    x = tf.placeholder(shape=dshape, dtype=dtype, name=input_name)
    mp1 = tf.nn.max_pool3d(x, ksize=[1,8,1,1,1], padding="SAME", strides=[1,1,1,1,1])
    rsh1 = tf.reshape(mp1, [-1, 16384])
    mm1 = tf.matmul(rsh1, mm_weights, name="matmul")
    graph_def = sess.graph_def
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        graph_def,
        ["matmul",])

mod, params = relay.frontend.from_tensorflow(output_graph_def, layout='NCHW', shape={input_name: dshape})

target = "llvm"

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target, params=params)

print("Compilation done")

We cannot run tensor array ops through relay.build because they produce dynamic shapes; the Relay VM or interpreter is needed for this type of model. But I am not sure why or where it hangs, as I haven't debugged it.

Sorry, the above code does produce tensor array ops, but only the main function is used, and it doesn't contain any tensor array.
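
For reference, a minimal sketch (assuming the relay.create_executor API of this TVM version, and reusing mod, params, dshape and dtype from the script above) of how the same module could be run through the Relay VM, which is the path that supports tensor arrays and dynamic shapes:

import tvm

ctx = tvm.cpu()
with relay.build_config(opt_level=3):
    # "vm" selects the Relay VM instead of the graph runtime
    vm_exec = relay.create_executor("vm", mod=mod, ctx=ctx, target="llvm")

data = np.random.random_sample(dshape).astype(dtype)
result = vm_exec.evaluate()(data, **params)

This does not explain the hang; it only shows the executor that handles dynamic shapes.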

It looks like the following pure-Relay code hangs as well; we might need to investigate further:

from tvm import relay

# Pure Relay reproducer of the same max_pool3d + reshape + dense pattern
dshape = (1, 8, 2, 2, 512)

x1 = relay.var("x1", shape=dshape)
x2 = relay.var("x2", shape=(16384, 32))

z0 = relay.transpose(x1, axes=(0, 4, 1, 2, 3))    # NDHWC -> NCDHW
z1 = relay.nn.max_pool3d(z0, pool_size=[8, 1, 1], padding=[3, 0, 0, 4, 0, 0])
z2 = relay.transpose(z1, axes=[0, 2, 3, 4, 1])    # NCDHW -> NDHWC
z3 = relay.reshape(z2, newshape=[-1, 16384])      # flatten to (1, 16384)
z4 = relay.transpose(x2, axes=[1, 0])             # weight as (units, in_dim)
z5 = relay.nn.dense(z3, z4, units=32)
func = relay.Function([x1, x2], z5)
mod = relay.Module.from_expr(func)

print(mod)
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm")

print("Compilation done")

Yes, your example hangs too. I also tried opt_level=0; relay.build still hangs.
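
To narrow down where it gets stuck, one option (a hypothetical debugging aid, not something from this thread) is the standard-library faulthandler, which can periodically dump the Python stacks while relay.build hangs; if the hang is on the C++/LLVM side, the trace will stop at the relay.build FFI call:

import faulthandler
from tvm import relay

# Dump all Python thread stacks every 60 seconds until cancelled.
faulthandler.dump_traceback_later(60, repeat=True)

with relay.build_config(opt_level=3):
    # mod is the Relay module from the snippet above
    graph, lib, params = relay.build(mod, target="llvm")

faulthandler.cancel_dump_traceback_later()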

BTW, can you remove the 2-space indentation in your code example to simplify copy-and-paste into the Python CLI?

It seems the dense_pack schedule is difficult for LLVM to compile. Selectively simplifying the schedule points to one particular vectorize call as the cause. For now, the following line can be commented out:

diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py
index b7a3d6d..ad7e404 100644
--- a/topi/python/topi/x86/dense.py
+++ b/topi/python/topi/x86/dense.py
@@ -191,7 +191,7 @@ def _schedule_dense_pack_template(cfg, s, C):
     z, y, x = s[packedB].op.axis
     s[packedB].reorder(z, x, y)
     s[packedB].parallel(z)
-    s[packedB].vectorize(y)
+    # s[packedB].vectorize(y)
     return s
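
For readers who have not looked at the dense_pack schedule, here is a standalone toy sketch of the same kind of packed-weight stage (assuming the pre-te API of this TVM version; the shapes, packing factor and layout are illustrative assumptions, not the actual topi code). The vectorize call is analogous to the statement the diff above disables:

import tvm

N, K, bn = 32, 16384, 16    # illustrative sizes; bn is an assumed packing factor
B = tvm.placeholder((K, N), name="B")
packedB = tvm.compute((N // bn, K, bn),
                      lambda z, y, x: B[y, z * bn + x], name="packedB")

s = tvm.create_schedule(packedB.op)
z, y, x = s[packedB].op.axis
s[packedB].reorder(z, x, y)
s[packedB].parallel(z)
s[packedB].vectorize(y)    # analogous to the statement commented out above
print(tvm.lower(s, [B, packedB], simple_mode=True))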

@icemelon9 @yinghai @Laurawly Could you comment on this issue? You were working on https://github.com/apache/incubator-tvm/pull/2561

This line causes the issue: https://github.com/apache/incubator-tvm/blob/master/topi/python/topi/x86/dense.py#L194

BTW, the runtime fails for 5D input tensors as well.