I have a GEMM C=A*B
(where the shapes are A [2048, 256], B[256, 2048], C[2048, 2048])
and the following schedule
i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
all = s[C].fuse(i_inner, j_inner)
i, j = s[C].split(all, nparts=512)
I create a tensorize intrinsic that computes a 64x64x256 GEMM.
gemv = intrin_gemv(64, 64, K)
I am surprised to find that both
s[C].tensorize(i, gemv)
and
s[C].tensorize(j, gemv)
can be successfully matched and lowered.
The former produces
produce C { for (i.outer, 0, 32) { for (j.outer, 0, 32) { gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)])) } } }
while the latter produces
produce C { for (i.outer, 0, 32) { for (j.outer, 0, 32) { for (i.inner.j.inner.fused.outer, 0, 512) { gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)])) } } } }
Questions.
- Is
s[C].tensorize(i, gemv)
supposed to work?
The loop nest starting at i
computes the same tensor domain as the intrinsic function does. But the loop structures are very different.
- Is
s[C].tensorize(j, gemv)
supposed to work?
This seems a bug to me.
Thanks.
Yuan
P.S. the complete test case
import tvm
import numpy as np

# Problem sizes: C[M, N] = A[M, K] * B[K, N]
M, N, K = 2048, 2048, 256

# Reduction axis over the shared K dimension.
k = tvm.reduce_axis((0, K), 'k')
A = tvm.placeholder((M, K), name='A')
B = tvm.placeholder((K, N), name='B')
# Plain dense matmul expressed as a sum over k.
C = tvm.compute((M, N),
                lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k),
                name='C')
def intrin_gemv(m, n, kk):
    """Declare a tensor intrinsic computing an m x n x kk GEMM update.

    The intrinsic's compute pattern mirrors the outer GEMM:
    c[i, j] = sum_k a[i, k] * b[k, j], lowered to extern calls
    `gemv_update` (body / reduce-update) and `gemv_reset` (reduce-init).

    Parameters
    ----------
    m, n, kk : int
        Dimensions of the tile the intrinsic covers.

    Returns
    -------
    TensorIntrin
        The declared intrinsic, suitable for s[C].tensorize(...).

    NOTE(review): the indentation of this function was lost in the
    original paste; it has been restored here so the code is runnable.
    """
    a = tvm.placeholder((m, kk), name='a')
    b = tvm.placeholder((kk, n), name='b')
    k = tvm.reduce_axis((0, kk), name='k')
    c = tvm.compute((m, n),
                    lambda i, j: tvm.sum(a[i, k] * b[k, j], axis=k),
                    name='c')
    # Buffers use symbolic row strides so the intrinsic can bind to
    # sub-tiles of larger tensors whose rows are not contiguous.
    Ab = tvm.decl_buffer(a.shape, a.dtype,
                         name="A",
                         offset_factor=1,
                         strides=[tvm.var("sa"), 1])
    Bb = tvm.decl_buffer(b.shape, b.dtype,
                         name="B",
                         offset_factor=1,
                         strides=[tvm.var("sb"), 1])
    Cb = tvm.decl_buffer(c.shape, c.dtype,
                         name="C",
                         offset_factor=1,
                         strides=[tvm.var("sc"), 1])

    def intrin_func(ins, outs):
        # ins/outs are the bound buffers in declaration order.
        aa, bb = ins
        cc = outs[0]

        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_update",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r"),
                                    bb.access_ptr("r")))
            return ib.get()

        def _reduce_reset():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_reset", cc.access_ptr("w")))
            return ib.get()

        def _reduce_update():
            # Update step is identical to the full body for this intrinsic.
            return _body()

        # (body, reduce_reset, reduce_update) triple expected by TVM.
        return _body(), _reduce_reset(), _reduce_update()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
# Schedule 1: tile 64x64, fuse the inner tile axes, split the fused axis,
# then tensorize at the OUTER half (i) of the split.
s = tvm.create_schedule(C.op)
i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
fused = s[C].fuse(i_inner, j_inner)  # renamed from `all`: avoid shadowing the builtin
i, j = s[C].split(fused, nparts=512)
print(tvm.lower(s, [A, B, C], simple_mode=True))
gemv = intrin_gemv(64, 64, K)
s[C].tensorize(i, gemv)  # The only difference between schedule 1 and 2.
print(tvm.lower(s, [A, B, C], simple_mode=True))

# Schedule 2: identical tiling/fusing/splitting, but tensorize at the
# INNER half (j) of the split instead.
s = tvm.create_schedule(C.op)
i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
fused = s[C].fuse(i_inner, j_inner)
i, j = s[C].split(fused, nparts=512)
gemv = intrin_gemv(64, 64, K)
s[C].tensorize(j, gemv)  # The only difference between schedule 1 and 2.
print(tvm.lower(s, [A, B, C], simple_mode=True))
Output
produce C {
for (i.outer, 0, 32) {
for (j.outer, 0, 32) {
gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
}
}
}
produce C {
for (i.outer, 0, 32) {
for (j.outer, 0, 32) {
for (i.inner.j.inner.fused.outer, 0, 512) {
gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
}
}
}
}