Domain of tvm.thread_axis


#1

tvm.thread_axis() accepts an optional domain argument. Could anyone please explain how the domain is used?

I tried the following example and do not see how it affects the schedule.

import tvm

M = 7*5*2
N = 11*9*2

A = tvm.placeholder((M, N), name='A')

B = tvm.compute((M, N), lambda i,j: 3.14 * A[i, j], name='B')

s = tvm.create_schedule(B.op)

# Tile factors: 70/10 = 7 tiles along i, 198/18 = 11 tiles along j
Mblock = 5*2
Nblock = 9*2

i_outer, j_outer, i_inner, j_inner = s[B].tile(B.op.axis[0], B.op.axis[1], Mblock, Nblock)

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
# Thread axes declared with explicit domains (0, 5) and (0, 9)
thread_x = tvm.thread_axis((0,5), "threadIdx.x")
thread_y = tvm.thread_axis((0,9), "threadIdx.y")

# Outer (tile) loops -> blocks, inner loops -> threads
s[B].bind(i_outer, block_y)
s[B].bind(j_outer, block_x)
s[B].bind(i_inner, thread_y)
s[B].bind(j_inner, thread_x)

# Print the lowered IR shown below
print(tvm.lower(s, [A, B], simple_mode=True))

Output

produce B {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 7
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 11
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 10
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 18
  B[(((((blockIdx.y*110) + blockIdx.x) + (threadIdx.y*11))*18) + threadIdx.x)] = (A[(((((blockIdx.y*110) + blockIdx.x) + (threadIdx.y*11))*18) + threadIdx.x)]*3.140000f)
}
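The mapping in the printed IR can be checked without TVM. The plain-Python sketch below copies the flattened index from the output, rebuilds i and j from the tile/bind mapping, and walks the launch grid using the printed thread_extent values (7, 11, 10, 18). Those extents equal the tile counts and tile factors, not the declared (0,5) and (0,9) domains, and they write every element of B exactly once. The helper names `lowered_index` and `logical_index` are illustrative only.

```python
# Pure-Python check of the lowered IR printed above (no TVM needed).
M, N = 7 * 5 * 2, 11 * 9 * 2      # 70 x 198, as in the question
Mblock, Nblock = 5 * 2, 9 * 2     # tile factors: 10 x 18

# Extents copied from the thread_extent attributes in the output.
grid_y, grid_x = 7, 11            # blockIdx.y, blockIdx.x
block_y, block_x = 10, 18         # threadIdx.y, threadIdx.x
assert (grid_y, grid_x) == (M // Mblock, N // Nblock)
assert (block_y, block_x) == (Mblock, Nblock)


def lowered_index(by, bx, ty, tx):
    """Flattened index exactly as written in the printed IR."""
    return (((by * 110) + bx) + (ty * 11)) * 18 + tx


def logical_index(by, bx, ty, tx):
    """Row-major i * N + j, with i and j rebuilt from the tile/bind mapping."""
    i = by * Mblock + ty   # i_outer -> blockIdx.y, i_inner -> threadIdx.y
    j = bx * Nblock + tx   # j_outer -> blockIdx.x, j_inner -> threadIdx.x
    return i * N + j


touched = []
for by in range(grid_y):
    for bx in range(grid_x):
        for ty in range(block_y):
            for tx in range(block_x):
                idx = lowered_index(by, bx, ty, tx)
                assert idx == logical_index(by, bx, ty, tx)
                touched.append(idx)

# Every element of B is written exactly once.
print(len(touched), sorted(touched) == list(range(M * N)))  # 13860 True
```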

#2

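The domain is ignored when a special thread tag such as “blockIdx.x” or “threadIdx.x” is used; the extent is taken from the loop the axis is bound to, which is why the output shows threadIdx.y = 10 and threadIdx.x = 18 rather than 9 and 5. See https://docs.tvm.ai/api/python/tvm.html#tvm.thread_axis.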
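A toy model of that rule in plain Python, not taken from TVM's source: for a special tag, the recorded thread_extent is the extent of the loop the axis is bound to, and any declared domain is accepted but never consulted. The tag set (the six CUDA-style tags) and the function `emitted_thread_extent` are assumptions of this sketch. The last line is plain arithmetic on what a kernel would cover if it launched the declared domains with one element per thread.

```python
# Toy model (hypothetical, not TVM source) of the rule stated in the answer.
# Assumption: these six CUDA-style tags are the "special" ones.
SPECIAL_TAGS = {
    "blockIdx.x", "blockIdx.y", "blockIdx.z",
    "threadIdx.x", "threadIdx.y", "threadIdx.z",
}


def emitted_thread_extent(tag, bound_loop_extent, declared_dom=None):
    """Extent recorded in the thread_extent attribute under this model.

    For a special tag, declared_dom is ignored; the extent comes from
    the loop the axis was bound to.
    """
    if tag not in SPECIAL_TAGS:
        raise NotImplementedError("only special thread tags are modelled here")
    return bound_loop_extent


# The binds from the question: i_inner has extent 10, j_inner has extent 18.
ty = emitted_thread_extent("threadIdx.y", 10, declared_dom=(0, 9))
tx = emitted_thread_extent("threadIdx.x", 18, declared_dom=(0, 5))
print(ty, tx)  # 10 18, matching the printed IR

# Launching the declared domains (9 x 5 threads, one element each) would
# cover only part of each 10 x 18 tile.
print(9 * 5, "of", ty * tx)  # 45 of 180
```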


#3

Thanks! Now I understand what it means.

Should the documentation entry

dom (Range or str) – The domain of iteration When str is passed, dom is set to None and str is used as tag

be interpreted as the following?

dom (Range) – The domain of iteration. When tag is passed, dom is set to None.
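One possible reading of the quoted sentence, sketched in plain Python: "str" refers to the type of the first argument. Assuming `dom` is the first positional parameter (consistent with the calls `tvm.thread_axis("blockIdx.x")` and `tvm.thread_axis((0,5), "threadIdx.x")` in the question), a lone string lands in `dom` and is moved into `tag`, while a tag passed as the second argument leaves `dom` as given. `split_thread_axis_args` is a hypothetical helper, not part of TVM.

```python
def split_thread_axis_args(dom=None, tag=""):
    """Hypothetical mirror of the documented rule for thread_axis arguments.

    Assumes the parameter order thread_axis(dom=None, tag=''), so a lone
    string such as "threadIdx.x" lands in `dom` and is moved to `tag`.
    """
    if isinstance(dom, str):
        # "When str is passed, dom is set to None and str is used as tag"
        tag, dom = dom, None
    if not tag:
        raise ValueError("a thread tag is required")
    return dom, tag


print(split_thread_axis_args("blockIdx.x"))           # (None, 'blockIdx.x')
print(split_thread_axis_args((0, 5), "threadIdx.x"))  # ((0, 5), 'threadIdx.x')
print(split_thread_axis_args(tag="threadIdx.y"))      # (None, 'threadIdx.y')
```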