[tensorize] Which use case is correct?

I have a GEMM C = A*B (where the shapes are A [2048, 256], B [256, 2048], C [2048, 2048])

and the following schedule

i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
fused = s[C].fuse(i_inner, j_inner)
i, j = s[C].split(fused, nparts=512)

I declare a tensor intrinsic that computes a 64x64x256 GEMM.

gemv = intrin_gemv(64, 64, K)

I am surprised to find that both

 s[C].tensorize(i, gemv)

and

s[C].tensorize(j, gemv)

can be successfully matched and lowered.

The former produces

produce C {
  for (i.outer, 0, 32) {
    for (j.outer, 0, 32) {
      gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
    }
  }
}

while the latter produces

produce C {
  for (i.outer, 0, 32) {
    for (j.outer, 0, 32) {
      for (i.inner.j.inner.fused.outer, 0, 512) {
        gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
      }
    }
  }
}

Questions.

  1. Is s[C].tensorize(i, gemv) supposed to work?

The loop nest starting at i computes the same tensor domain as the intrinsic does, but the loop structures are very different.

  2. Is s[C].tensorize(j, gemv) supposed to work?

This seems like a bug to me.

Thanks.

Yuan

P.S. the complete test case

import tvm
import numpy as np

M = 2048
N = 2048
K = 256

k = tvm.reduce_axis((0, K), 'k')

A = tvm.placeholder((M, K), name='A')
B = tvm.placeholder((K, N), name='B')

C = tvm.compute(
    (M, N),
    lambda i, j:  tvm.sum(A[i, k]* B[k, j], axis=k),
    name='C'
)

def intrin_gemv(m, n, kk):
    a = tvm.placeholder((m, kk), name='a')
    b = tvm.placeholder((kk, n), name='b')
    k = tvm.reduce_axis((0, kk), name='k')
    c = tvm.compute((m,n), lambda i, j:
        tvm.sum(a[i, k] * b[k, j], axis=k), name='c')
    Ab = tvm.decl_buffer(a.shape, a.dtype,
                         name="A",
                         offset_factor=1,
                         strides=[tvm.var("sa"), 1])
    Bb = tvm.decl_buffer(b.shape, b.dtype,
                         name="B",
                         offset_factor=1,
                         strides=[tvm.var("sb"), 1])
    Cb = tvm.decl_buffer(c.shape, c.dtype,
                         name="C",
                         offset_factor=1,
                         strides=[tvm.var("sc"), 1])
    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]
        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_update",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r"),
                                    bb.access_ptr("r")))
            return ib.get()
        def _reduce_reset():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_reset", cc.access_ptr("w")))
            return ib.get()
        def _reduce_update():
            return _body()
        # tensorize expects a (body, reduce_reset, reduce_update) triple
        return _body(), _reduce_reset(), _reduce_update()
    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
    
# Schedule 1
s = tvm.create_schedule(C.op)

i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
fused = s[C].fuse(i_inner, j_inner)
i, j = s[C].split(fused, nparts=512)

print(tvm.lower(s, [A, B, C], simple_mode=True))

gemv = intrin_gemv(64, 64, K)
s[C].tensorize(i, gemv)        # The only difference between schedule 1 and 2.

print(tvm.lower(s, [A, B, C], simple_mode=True))


# Schedule 2
s = tvm.create_schedule(C.op)

i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
fused = s[C].fuse(i_inner, j_inner)
i, j = s[C].split(fused, nparts=512)

gemv = intrin_gemv(64, 64, K)
s[C].tensorize(j, gemv)        # The only difference between schedule 1 and 2.

print(tvm.lower(s, [A, B, C], simple_mode=True))

Output

produce C {
  for (i.outer, 0, 32) {
    for (j.outer, 0, 32) {
      gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
    }
  }
}

produce C {
  for (i.outer, 0, 32) {
    for (j.outer, 0, 32) {
      for (i.inner.j.inner.fused.outer, 0, 512) {
        gemv_update(tvm_address_of(C[(((i.outer*2048) + j.outer)*64)]), tvm_address_of(A[(i.outer*16384)]), tvm_address_of(B[(j.outer*64)]))
      }
    }
  }
}

I think this is not a bug; it is by design.

In the second schedule, i is considered to have range (min=0, extent=1).

You could say that, inside that loop nest, i is treated as a constant. The compute op pattern then matches what you have in the gemv declaration, and tensorize is able to match your gemv pattern.

@umangyadav Thanks for your reply.

If i is considered to have range (min=0, extent=1), then wouldn't the lowered code have for (i.inner.j.inner.fused.outer, 0, 1) instead of for (i.inner.j.inner.fused.outer, 0, 512)?

The domain of the tensor computed by the intrinsic is the same in both cases, right? If so, one of the two cases must be wrong, I think. What am I missing?
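
For what it's worth, the inferred ranges of the leaf IterVars can be inspected directly (a quick check, assuming the schedule s from schedule 2 above; InferBound is the same pass that tensorize relies on):

s = s.normalize()
bounds = tvm.schedule.InferBound(s)
for iv in s[C].leaf_iter_vars:
    print(iv, bounds[iv])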

I did some further study and figured out what happens in case 2.

The following is a simplified test case that illustrates the same issue.

import tvm
import numpy as np
        
M = 64
N = 64

A = tvm.placeholder((M, N), name='A')
C = tvm.compute(
    (M, N),
    lambda i, j: A[i,j], 
    name='C'
)

def intrin_copy(m, n):
    a = tvm.placeholder((m, n), name='a')
    c = tvm.compute((m,n), lambda i, j: a[i,j], name='c')
    Ab = tvm.decl_buffer(a.shape, a.dtype,
                         name="A",
                         offset_factor=1,
                         strides=[tvm.var("sa"), 1])
    Cb = tvm.decl_buffer(c.shape, c.dtype,
                         name="C",
                         offset_factor=1,
                         strides=[tvm.var("sc"), 1])
    def intrin_func(ins, outs):
        aa = ins[0]
        cc = outs[0]
        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "copy",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r"),
                                    m, n))
            return ib.get()
        return _body()
    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})
    
# Schedule 1
s = tvm.create_schedule(C.op)

fused = s[C].fuse(C.op.axis[0], C.op.axis[1])
i, j = s[C].split(fused, nparts=512)

print("No Tensorize")
print(tvm.lower(s, [A, C], simple_mode=True))

copy = intrin_copy(64, 64)
s[C].tensorize(i, copy)     # The only difference between schedule 1 and 2.

print("Case 1")
print(tvm.lower(s, [A, C], simple_mode=True))


# Schedule 2
s = tvm.create_schedule(C.op)

fused = s[C].fuse(C.op.axis[0], C.op.axis[1])
i, j = s[C].split(fused, nparts=512)

copy = intrin_copy(64, 64)
s[C].tensorize(j, copy)     # The only difference between schedule 1 and 2.

print("Case 2")
print(tvm.lower(s, [A, C], simple_mode=True))

Output

No Tensorize
produce C {
  for (i.j.fused.outer, 0, 512) {
    for (i.j.fused.inner, 0, 8) {
      C[((i.j.fused.outer*8) + i.j.fused.inner)] = A[((i.j.fused.outer*8) + i.j.fused.inner)]
    }
  }
}

Case 1
produce C {
  copy(tvm_address_of(C[0]), tvm_address_of(A[0]), 64, 64)
}

Case 2
produce C {
  for (i.j.fused.outer, 0, 512) {
    copy(tvm_address_of(C[0]), tvm_address_of(A[0]), 64, 64)
  }
}

The tensorizer needs to calculate the domain of the computed tensor at the specified scope. It does this by calling PassUpDomain() on the ranges of the leaf IterVars, similar to bound inference (discussed here). In this case, however, PassUpDomain() cannot model the domain precisely (see the code), and has to make a conservative assumption that claims the whole tensor is computed.
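To make that concrete, here is a rough sketch in plain Python of how I read the fuse pass-up rule (hypothetical code, not TVM's actual implementation; pass_up_domain_fuse is a made-up name):

# Hypothetical sketch of PassUpDomain() over a fuse relation: recover
# ranges for (outer, inner) from a range of the fused IterVar.
def pass_up_domain_fuse(fused_min, fused_extent, outer_extent, inner_extent):
    if fused_extent == outer_extent * inner_extent:
        # the range covers the whole fused axis: full ranges, exact
        return (0, outer_extent), (0, inner_extent)
    if fused_extent == 1:
        # a single point maps back exactly
        return (fused_min // inner_extent, 1), (fused_min % inner_extent, 1)
    # any other range is generally not a rectangular region in
    # (outer, inner) space, so conservatively claim the entire domain
    return (0, outer_extent), (0, inner_extent)

Under this rule, case 2's tensorize scope covers only j (i.j.fused.inner), so the range passed up to the fuse relation has extent 8: neither a single point nor the full 4096, hence the fallback claims the whole 64x64 domain. That is why intrin_copy(64, 64) matches, and why each of the 512 outer iterations recomputes the same region. In case 1 the scope covers the whole fused extent, so the full domain is exact rather than conservative.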

I do not fully understand the ramifications here, but it does seem to make the tensorizer trickier to use.

  1. One has to understand when/how the conservative domain is computed. For example, in the case above, how would one know that intrin_copy(64, 64) should be used instead of the more intuitive intrin_copy(1, 8), which would result in a pattern mismatch?
  2. Redundant computation can be introduced, as seen in the output of case 2 (a comparison sketch follows the quoted output below):
produce C {
  for (i.j.fused.outer, 0, 512) {
    copy(tvm_address_of(C[0]), tvm_address_of(A[0]), 64, 64)
  }
}
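
For comparison, tensorizing directly at the tiled inner axis avoids passing a range up through a fuse relation altogether, so the matched domain is the exact 64x64 tile (a sketch, reusing the definitions from the GEMM test case earlier in the thread):

s = tvm.create_schedule(C.op)
i_outer, j_outer, i_inner, j_inner = s[C].tile(C.op.axis[0], C.op.axis[1], 64, 64)
gemv = intrin_gemv(64, 64, K)
s[C].tensorize(i_inner, gemv)   # scope is the 64x64 tile; no fuse/split involved
print(tvm.lower(s, [A, B, C], simple_mode=True))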

I am also running into a problem with PassUpDomain() while trying to fuse and split two axes. It looks like this task has been set aside for the last two years. Maybe we should create an issue or a bug report, unless one already exists?
