I created a minimal sample that reproduces a symbolic-expression simplification issue:
import tvm
import topi
# Backend passed to tvm.build at the end of the script.
target = "cuda"
# Input shape: the leading extent is the symbolic variable "n";
# the remaining two extents are fixed.
dshape = (tvm.var("n"), 72, 96)
def compute(data):
    """Return a tensor op equal to ``data`` multiplied element-wise by 10.

    The output has the same (possibly symbolic) shape as ``data``. The index
    function takes three indices, so ``data`` is expected to be 3-D.
    """
    # Keep the index names i, j, k: tvm.compute derives the iteration
    # variable names from the index function's parameter names.
    def _times_ten(i, j, k):
        return data[i, j, k] * 10

    return tvm.compute(data.shape, _times_ten)
def schedule(s, out):
    """Tile the first two axes of ``out`` and bind the tiles to GPU axes.

    Axis 0 is split by 32, then the inner part by 8. This yields
    (blockIdx.z, blockIdx.y, blockIdx.x). Axis 1 is split by 12, then the
    inner part by 1. This yields (threadIdx.z, threadIdx.y, threadIdx.x).
    The third axis stays a serial loop. All splits and binds are applied
    through ``s[out]``, and ``s`` itself is returned.
    """
    stage = s[out]
    axis0, axis1, _ = stage.op.axis

    blk_outer, blk_rest = stage.split(axis0, 32)
    blk_mid, blk_inner = stage.split(blk_rest, 8)
    thr_outer, thr_rest = stage.split(axis1, 12)
    thr_mid, thr_inner = stage.split(thr_rest, 1)

    # Bind in the same order as the splits above. Each thread_axis is
    # created immediately before its bind.
    bindings = (
        (blk_outer, "blockIdx.z"),
        (blk_mid, "blockIdx.y"),
        (blk_inner, "blockIdx.x"),
        (thr_outer, "threadIdx.z"),
        (thr_mid, "threadIdx.y"),
        (thr_inner, "threadIdx.x"),
    )
    for iter_var, tag in bindings:
        stage.bind(iter_var, tvm.thread_axis(tag))
    return s
# Declare the input, build the element-wise op, apply the GPU schedule,
# and compile everything for `target`.
d = tvm.placeholder(dshape, name="data")
out = compute(d)
s = schedule(tvm.create_schedule(out.op), out)
f = tvm.build(s, [d, out], target)
Lowered statement printed by tvm.build:
produce compute {
// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ((n + 31)/32)
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 8
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 6
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 12
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
if (((blockIdx.z < (((n - 32)/32) + 1)) && (blockIdx.z < (((n + 31)/32) - 1)))) {
for (k, 0, 96) {
compute[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)] = (data[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)]*10f)
}
} else {
for (k, 0, 96) {
if (((((blockIdx.z*32) + (blockIdx.y*8)) + blockIdx.x) < n)) {
if (((((blockIdx.z*32) + (blockIdx.y*8)) + blockIdx.x) < n)) {
compute[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)] = (data[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)]*10f)
}
}
}
}
}
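There are two problems:

1) The condition `if (((blockIdx.z < (((n - 32)/32) + 1)) && (blockIdx.z < (((n + 31)/32) - 1))))` is not simplified. Since `n` and `blockIdx.z` are non-negative, the first conjunct is implied by the second. The whole condition should therefore reduce to `blockIdx.z < (((n + 31)/32) - 1)`.

2) Several `if` statements sit inside the `for` loop over `k`, even though their conditions do not depend on `k`. The same condition is even checked twice, in two nested `if`s. These checks could be hoisted above the loop to reduce the number of times they are evaluated.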
@tqchen Do you think this can be improved by the simplifier, or by other parts of TVM?