How to schedule different shapes in contiguous stages

Hi, I’m writing a schedule for NHWC conv2d on Tensor Cores using the WMMA intrinsics, and I’ve run into a problem.

I need to split the wmma_m and wmma_n axes out of the n, h, w and c axes. wmma_n can be split from the c axis. To support batch=1, I want to split wmma_m from the w axis. wmma_m can be 8, 16 or 32, but w is usually not a multiple of wmma_m, so I tried to pad w to padded_w.

Here is my compute and schedule code. I pad the input to input_pad_w, compute conv_padded as a GEMM, and then unpad conv_padded to conv. The problem is that conv and conv_padded have different shapes, and conv_padded is marked inline. As a result, only part of conv_padded is computed in the lowered code, which makes tensorize fail.

def nhwc_tensorcore_int8_w(cfg, Input, Filter, strides, padding, dilation=1):
    """Declare an NHWC conv2d as an im2col GEMM whose W axis is padded up to
    a multiple of the ``MTile`` tuning knob.

    The GEMM ``M`` dimension is taken from the output-width axis, so that
    axis is padded (``conv_padded``) until it is divisible by ``MTile``. A
    final ``conv`` stage slices the padded result back to the real width.

    Parameters
    ----------
    cfg : autotvm config entity
        Tuning config; this function defines the ``MTile`` knob on it.
    Input : te.Tensor
        Input unpacked as (batch, in_height, in_width, in_channel).
    Filter : te.Tensor
        Weights unpacked as (kernel_h, kernel_w, in_channel, num_filter).
    strides : int or pair of int
        Stride along (height, width).
    padding : int, str or tuple
        Padding spec, forwarded to ``get_pad_tuple``.
    dilation : int or pair of int
        Dilation along (height, width).

    Returns
    -------
    te.Tensor
        int32 tensor of shape (batch, out_height, out_width, num_filter),
        tagged ``"conv2d_nhwc_tensorcore_int8_w"``.
    """
    assert isinstance(strides, int) or len(strides) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    # Accept either a scalar or an (h, w) pair, as the asserts above allow.
    if isinstance(strides, int):
        stride_h = stride_w = strides
    else:
        stride_h, stride_w = strides
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_height, in_width, in_channel = get_const_tuple(Input.shape)
    kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
    assert in_channel % 16 == 0 and num_filter % 8 == 0

    # compute output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_height = simplify(
        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify(
        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # Pad the spatial borders. All four sides are checked so that
    # bottom/right-only (asymmetric) padding is not silently dropped.
    if pad_top or pad_left or pad_down or pad_right:
        input_pad_side = nn.pad(Input, [0, pad_top, pad_left, 0],
                                [0, pad_down, pad_right, 0],
                                name="input_pad_side")
    else:
        input_pad_side = Input

    # GEMM reduction length: one column per (kernel row, kernel col, channel).
    K = kernel_h * kernel_w * in_channel
    input_shape = (batch, out_height, out_width, K)

    # im2col. The column index decomposes as
    #   y = (kernel_row * kernel_w + kernel_col) * in_channel + channel,
    # matching the GEMM indexing of ``conv_padded`` below. Kernel taps are
    # spaced by the dilation, consistent with the output-shape computation.
    if kernel_h == 1 and kernel_w == 1:
        input_im2col = te.compute(
            input_shape,
            lambda b, h, w, y: input_pad_side[b, stride_h * h, stride_w * w, y],
            name="input_im2col")
    else:
        input_im2col = te.compute(
            input_shape,
            lambda b, h, w, y: input_pad_side[
                b,
                stride_h * h + dilation_h * ((y // in_channel) // kernel_w),
                stride_w * w + dilation_w * ((y // in_channel) % kernel_w),
                y % in_channel],
            name="input_im2col")

    # Pad the output-width axis of the im2col matrix up to a multiple of
    # MTile so the GEMM M dimension tiles evenly. The schedule reads this
    # knob back from ``cfg``.
    cfg.define_knob("MTile", [8, 16, 32, 64])
    padding_factor = cfg["MTile"].val
    pad_w = 0
    if out_width % padding_factor != 0:
        pad_w = padding_factor - (out_width % padding_factor)
    if pad_w != 0:
        input_pad_w = nn.pad(input_im2col, [0, 0, 0, 0], [0, 0, pad_w, 0],
                             name="input_pad_w")
    else:
        input_pad_w = input_im2col

    # GEMM over the padded width
    padded_w = out_width + pad_w
    rx = te.reduce_axis((0, kernel_h), name='rx')
    ry = te.reduce_axis((0, kernel_w), name='ry')
    rc = te.reduce_axis((0, in_channel), name='rc')

    conv_padded = te.compute(
        (batch, out_height, padded_w, num_filter),
        lambda b, h, w, y: te.sum(
            input_pad_w[b, h, w, (rx * kernel_w + ry) * in_channel + rc].astype("int32")
            * Filter[rx, ry, rc, y].astype("int32"),
            axis=[rx, ry, rc]),
        name="conv_padded")

    # Slice the padded GEMM result back to the real output width.
    conv = te.compute((batch, out_height, out_width, num_filter),
                      lambda b, h, w, c: conv_padded[b, h, w, c],
                      name="conv", tag="conv2d_nhwc_tensorcore_int8_w")

    return conv


def schedule_nhwc_tensorcore_int8_w(cfg, s, conv):
    """Schedule the padded-W NHWC conv2d for Tensor Core WMMA intrinsics.

    Expects ``conv`` to be the unpad stage built by
    ``nhwc_tensorcore_int8_w``. Its first input is ``conv_padded``, whose
    inputs are (``input_pad_w`` or ``input_im2col``, weight).

    Parameters
    ----------
    cfg : autotvm config entity
        Tuning config. The knobs read here are defined on it, except
        ``MTile``, which must already have been defined by the compute
        declaration.
    s : te.Schedule
        Schedule, mutated in place.
    conv : te.Tensor
        The unpad ("conv") stage.
    """
    # TODO: support unpad
    conv_padded = s[conv].op.input_tensors[0]
    rx, ry, rc = s[conv_padded].op.reduce_axis
    input_pad_w, weight = s[conv_padded].op.input_tensors
    has_pad_w = input_pad_w.op.name == "input_pad_w"
    if has_pad_w:
        input_im2col = s[input_pad_w].op.input_tensors[0]
    else:
        input_im2col = input_pad_w
    input_pad_side = s[input_im2col].op.input_tensors[0]

    # compute inline
    # NOTE(review): ``conv_padded`` is inlined into ``conv``, whose W extent
    # is only ``out_width``. Bound inference therefore appears to shrink the
    # W region of CS_padded / CF_padded / AF to ``out_width`` rows. That
    # produces the ``likely(w < out_width)`` guards which break tensorize
    # when out_width is not a multiple of wmma_m. Keeping the padded stage
    # materialized (and inlining the slice instead) may avoid this — TODO
    # confirm.
    s[conv_padded].compute_inline()
    if has_pad_w:
        s[input_pad_w].compute_inline()
    s[input_im2col].compute_inline()
    if input_pad_side.op.name == "input_pad_side":
        s[input_pad_side].compute_inline()

    # Designate the memory hierarchy
    AS = s.cache_read(input_pad_w, "shared", [conv_padded])
    WS = s.cache_read(weight, "shared", [conv_padded])
    AF = s.cache_read(AS, "wmma.matrix_a", [conv_padded])
    WF = s.cache_read(WS, "wmma.matrix_b", [conv_padded])
    CF_padded = s.cache_write(conv_padded, "wmma.accumulator")
    CS_padded = s.cache_read(CF_padded, "shared", [conv_padded])

    if conv.op in s.outputs:
        output = conv
        OL = s.cache_write(conv, "shared")
    else:
        output = s.outputs[0].output(0)
        s[conv].set_scope("shared")
        OL = conv

    # Schedule for autotvm ("MTile" is defined by the compute declaration)
    cfg.define_knob("NTile", [32, 64, 128, 256])
    cfg.define_knob("KTile", [32, 64, 128, 256])
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("vector_width", [1, 4, 8, 16])
    cfg.define_knob("offset", [0, 16])
    cfg.define_knob("wmma_m", [8, 16, 32])
    cfg.define_knob("vthread", [1, 2])

    MTile = cfg["MTile"].val
    NTile = cfg["NTile"].val
    KTile = cfg["KTile"].val
    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    chunk = cfg["chunk"].val
    vector_width = cfg["vector_width"].val
    offset = cfg["offset"].val
    vthread = cfg["vthread"].val
    wmma_m = cfg["wmma_m"].val

    # Supported int8 fragment shapes: (8, 32, 16), (16, 16, 16), (32, 8, 16).
    wmma_k = 16
    wmma_n = 16
    if wmma_m == 8:
        wmma_n = 32
    elif wmma_m == 32:
        wmma_n = 8
    wmma_shape = (wmma_m, wmma_n, wmma_k)

    warp_size = 32

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    block_z = te.thread_axis("blockIdx.z")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")
    thread_z = te.thread_axis("threadIdx.z")
    thread_vx = te.thread_axis((0, vthread), "vthread", name="vx")
    thread_vy = te.thread_axis((0, vthread), "vthread", name="vy")

    # Define the intrin strides (row-major: stride[i] = prod(extents[i:]))
    def get_strides(extents):
        return [np.prod(extents[i:]).tolist() for i in range(len(extents))]

    AS_align = KTile + offset
    WS_align = NTile + offset
    CS_align = NTile + offset
    NFrag = (NTile + block_col_warps - 1) // block_col_warps
    KFrag = chunk * wmma_k
    AS_shape = [wmma_m, wmma_k]
    AF_shape = [wmma_m, wmma_k]
    WS_shape = [wmma_k, wmma_n]
    WF_shape = [wmma_k, wmma_n]
    CS_shape = [wmma_m, wmma_n]
    CF_shape = [wmma_m, wmma_n]
    AS_strides = get_strides([AS_align, 1])
    AF_strides = get_strides([KFrag, 1])
    WS_strides = get_strides([WS_align, 1])
    WF_strides = get_strides([NFrag, 1])
    CF_strides = get_strides([NFrag, 1])
    # CS_padded's W stride is padded to CS_align by storage_align below, so
    # the store intrinsic must be declared with that same row stride
    # (NTile alone only matches when offset == 0).
    CS_strides = get_strides([CS_align, 1])

    # Schedule for output
    b, h, w, n = output.op.axis
    block_k = s[output].fuse(b, h)
    block_j, m = s[output].split(w, factor=MTile)
    block_i, n = s[output].split(n, factor=NTile)
    s[output].reorder(block_k, block_j, block_i, m, n)
    t = s[output].fuse(m, n)
    t, ti = s[output].split(t, factor=vector_width)
    t, tx = s[output].split(t, factor=warp_size)
    t, ty = s[output].split(t, factor=block_col_warps)
    t, tz = s[output].split(t, factor=block_row_warps)
    s[output].bind(block_k, block_z)
    s[output].bind(block_j, block_y)
    s[output].bind(block_i, block_x)
    s[output].bind(tz, thread_z)
    s[output].bind(ty, thread_y)
    s[output].bind(tx, thread_x)

    # Schedule for conv
    s[OL].compute_at(s[output], block_i)
    b, h, w, n = OL.op.axis
    # NOTE(review): this aligns the ``h`` axis, while CS_padded and AS align
    # ``w``; confirm which axis was intended here.
    s[OL].storage_align(h, CS_align - 1, CS_align)
    tz, m = s[OL].split(w, nparts=block_row_warps)
    ty, n = s[OL].split(n, nparts=block_col_warps)
    s[OL].reorder(tz, ty, m, n)
    t = s[OL].fuse(m, n)
    t, ti = s[OL].split(t, factor=vector_width)
    tx, _ = s[OL].split(t, nparts=warp_size)
    s[OL].bind(tz, thread_z)
    s[OL].bind(ty, thread_y)
    s[OL].bind(tx, thread_x)

    # Schedule for wmma store
    s[CS_padded].compute_at(s[OL], h)
    b, h, w, n = CS_padded.op.axis
    s[CS_padded].storage_align(w, CS_align - 1, CS_align)
    tz, m = s[CS_padded].split(w, nparts=block_row_warps)
    ty, n = s[CS_padded].split(n, nparts=block_col_warps)
    mo_cs, mi_cs = s[CS_padded].split(m, factor=wmma_m)
    no_cs, ni_cs = s[CS_padded].split(n, factor=wmma_n)
    s[CS_padded].reorder(tz, ty, mo_cs, no_cs, mi_cs, ni_cs)
    s[CS_padded].bind(tz, thread_z)
    s[CS_padded].bind(ty, thread_y)

    # Schedule for wmma compute: rc -> (kfo: KTile) -> (_kfo, kfi: chunk)
    # -> (_kfi: wmma_k)
    s[CF_padded].compute_at(s[CS_padded], ty)
    b, h, w, n = CF_padded.op.axis
    mo_cf, mi_cf = s[CF_padded].split(w, factor=wmma_m)
    no_cf, ni_cf = s[CF_padded].split(n, factor=wmma_n)
    kfo, kfi = s[CF_padded].split(rc, factor=KTile)
    kfi, _kfi = s[CF_padded].split(kfi, factor=wmma_k)
    _kfo, kfi = s[CF_padded].split(kfi, factor=chunk)
    s[CF_padded].reorder(rx, ry, kfo, _kfo, kfi, mo_cf, no_cf, mi_cf, ni_cf, _kfi)

    s[AS].compute_at(s[CF_padded], kfo)
    s[WS].compute_at(s[CF_padded], kfo)
    s[AF].compute_at(s[CF_padded], _kfo)
    s[WF].compute_at(s[CF_padded], _kfo)

    # Schedule for input's shared memory
    b, h, w, k = AS.op.axis
    s[AS].storage_align(w, AS_align - 1, AS_align)
    _, w = s[AS].split(w, factor=MTile)
    tz, m = s[AS].split(w, nparts=block_row_warps)
    ty, k = s[AS].split(k, nparts=block_col_warps)
    s[AS].reorder(tz, ty, m, k)
    t = s[AS].fuse(m, k)
    t, ti = s[AS].split(t, factor=vector_width)
    tx, _ = s[AS].split(t, nparts=warp_size)
    s[AS].bind(tz, thread_z)
    s[AS].bind(ty, thread_y)
    s[AS].bind(tx, thread_x)
    s[AS].vectorize(ti)

    # Schedule for weight's shared memory
    h, w, ic, oc = WS.op.axis
    s[WS].storage_align(ic, WS_align - 1, WS_align)
    tvx, ic = s[WS].split(ic, nparts=vthread)
    tvy, oc = s[WS].split(oc, nparts=vthread)
    tz, ic = s[WS].split(ic, nparts=block_row_warps)
    ty, oc = s[WS].split(oc, nparts=block_col_warps)
    s[WS].reorder(tvx, tvy, tz, ty, ic, oc)
    t = s[WS].fuse(ic, oc)
    t, ti = s[WS].split(t, factor=vector_width)
    tx, _ = s[WS].split(t, nparts=warp_size)
    s[WS].bind(tz, thread_z)
    s[WS].bind(ty, thread_y)
    s[WS].bind(tx, thread_x)
    s[WS].bind(tvx, thread_vx)
    s[WS].bind(tvy, thread_vy)
    s[WS].vectorize(ti)

    # Schedule for input's local memory
    b, h, w, k = AF.op.axis
    mo_af, mi_af = s[AF].split(w, factor=wmma_m)
    ko_af, ki_af = s[AF].split(k, factor=wmma_k)
    s[AF].reorder(mo_af, ko_af, mi_af, ki_af)

    # Schedule for weight's local memory
    h, w, ic, oc = WF.op.axis
    ico_wf, ici_wf = s[WF].split(ic, factor=wmma_k)
    oco_wf, oci_wf = s[WF].split(oc, factor=wmma_n)
    s[WF].reorder(h, w, ico_wf, oco_wf, ici_wf, oci_wf)

    # tensorize the wmma process
    AF_gemm = te.placeholder(AF_shape, name='A', dtype='int8')
    WF_gemm = te.placeholder(WF_shape, name='B', dtype='int8')
    k_gemm = te.reduce_axis((0, wmma_k), name='k')
    CF_compute = te.compute(CF_shape,
                            lambda m, n: te.sum(AF_gemm[m, k_gemm].astype('int32') * WF_gemm[k_gemm, n].astype('int32'),
                                                axis=k_gemm),
                            name='C')

    s[AF].tensorize(mi_af, intrin_wmma_load_matrix_A(AF_strides, AS_strides, wmma_shape, 'row_major',
                                                     AS_shape, AF_shape, 'int8', 'shared'))
    s[WF].tensorize(ici_wf, intrin_wmma_load_matrix_W(WF_strides, WS_strides, wmma_shape, 'row_major',
                                                      WS_shape, WF_shape, 'int8', 'shared'))
    s[CS_padded].tensorize(mi_cs, intrin_wmma_store_matrix(CS_strides, CF_strides, wmma_shape, 'row_major',
                                                           'int32', CF_shape, CS_shape, 'shared'))
    s[CF_padded].tensorize(mi_cf, intrin_wmma_gemm(AF_gemm, WF_gemm, CF_compute, AF_strides, WF_strides,
                                                   CF_strides, wmma_shape, 'row_major', 'row_major', 'row_major'))

For example, with batch=1, in_size=28, in_channel=128, num_filter=128, kernel_size=3 and MTile=32, the program fails with: TVMError: Tensorize failed, split condition tir.likely(((ax2.inner.inner + (ax2.inner.outer*8)) < 28)) relies on var defined inside tensorize scope.

Here is the lowered code before tensorize:

primfn(input_1: handle, weight_1: handle, bias_1: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {bias: Buffer(bias_2: handle, int32, [1, 1, 1, 128], []),
             input: Buffer(input_2: handle, int8, [1, 28, 28, 128], []),
             compute: Buffer(compute_2: handle, int8, [1, 28, 28, 128], []),
             weight: Buffer(weight_2: handle, int8, [3, 3, 128, 128], [])}
  buffer_map = {input_1: input, weight_1: weight, bias_1: bias, compute_1: compute} {
  attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 28;
  attr [conv_padded.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
  allocate(conv_padded.wmma.accumulator, int32, [896]);
  attr [input_pad_w.shared: handle] "storage_scope" = "shared";
  allocate(input_pad_w.shared, int8, [896]);
  attr [weight.shared: handle] "storage_scope" = "shared";
  allocate(weight.shared, int8, [1024]);
  attr [input_pad_w.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
  allocate(input_pad_w.shared.wmma.matrix_a, int8, [448]);
  attr [weight.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
  allocate(weight.shared.wmma.matrix_b, int8, [512]);
  attr [conv_padded.wmma.accumulator.shared: handle] "storage_scope" = "shared";
  allocate(conv_padded.wmma.accumulator.shared, int32, [896]);
  attr [conv: handle] "storage_scope" = "shared";
  allocate(conv, int32, [900]);
  attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 1;
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 4 {
    attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
    attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1 {
      for (w.c.outer.init: int32, 0, 4) {
        for (w.c.inner.init: int32, 0, 8) {
          for (y.c.inner.init: int32, 0, 32) {
            if @tir.likely((((w.c.outer.init*8) + w.c.inner.init) < 28), dtype=bool) {
              conv_padded.wmma.accumulator[(((w.c.outer.init*256) + (w.c.inner.init*32)) + y.c.inner.init)] = 0
            }
          }
        }
      }
      for (rx: int32, 0, 3) {
        for (ry: int32, 0, 3) {
          for (rc.outer: int32, 0, 4) {
            attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
            for (ax2.inner.inner.ax3.inner.fused.outer.inner: int32, 0, 32) {
              if @tir.likely((threadIdx.x < 28), dtype=bool) {
                input_pad_w.shared[((threadIdx.x*32) + ax2.inner.inner.ax3.inner.fused.outer.inner)] = @tir.if_then_else(((((1 <= (blockIdx.z + rx)) && ((blockIdx.z + rx) < 29)) && (1 <= (threadIdx.x + ry))) && ((threadIdx.x + ry) < 29)), (int8*)input_2[(((((((blockIdx.z*3584) + (rx*3584)) + (threadIdx.x*128)) + (ry*128)) + (rc.outer*32)) + ax2.inner.inner.ax3.inner.fused.outer.inner) - 3712)]), 0i8, dtype=int8)
              }
            }
            attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
            for (ax2.inner.inner.ax3.inner.inner.fused.outer.inner: int32, 0, 32) {
              weight.shared[((threadIdx.x*32) + ax2.inner.inner.ax3.inner.inner.fused.outer.inner)] = (int8*)weight_2[((((((rx*49152) + (ry*16384)) + (rc.outer*4096)) + (threadIdx.x*128)) + (blockIdx.x*32)) + ax2.inner.inner.ax3.inner.inner.fused.outer.inner)])
            }
            for (rc.inner.outer.outer: int32, 0, 2) {
              for (ax2.outer: int32, 0, 4) {
                for (ax2.inner: int32, 0, 8) {
                  for (ax3.inner: int32, 0, 16) {
                    if @tir.likely((((ax2.outer*8) + ax2.inner) < 28), dtype=bool) {
                      input_pad_w.shared.wmma.matrix_a[(((ax2.outer*128) + (ax2.inner*16)) + ax3.inner)] = (int8*)input_pad_w.shared[((((ax2.outer*256) + (ax2.inner*32)) + (rc.inner.outer.outer*16)) + ax3.inner)])
                    }
                  }
                }
              }
              for (ax2.inner_1: int32, 0, 16) {
                for (ax3.inner_1: int32, 0, 32) {
                  weight.shared.wmma.matrix_b[((ax2.inner_1*32) + ax3.inner_1)] = (int8*)weight.shared[(((rc.inner.outer.outer*512) + (ax2.inner_1*32)) + ax3.inner_1)])
                }
              }
              for (w.c.outer: int32, 0, 4) {
                for (w.c.inner: int32, 0, 8) {
                  for (y.c.inner: int32, 0, 32) {
                    for (rc.inner.inner: int32, 0, 16) {
                      if @tir.likely((((w.c.outer*8) + w.c.inner) < 28), dtype=bool) {
                        conv_padded.wmma.accumulator[(((w.c.outer*256) + (w.c.inner*32)) + y.c.inner)] = ((int32*)conv_padded.wmma.accumulator[(((w.c.outer*256) + (w.c.inner*32)) + y.c.inner)]) + (cast(int32, (int8*)input_pad_w.shared.wmma.matrix_a[(((w.c.outer*128) + (w.c.inner*16)) + rc.inner.inner)]))*cast(int32, (int8*)weight.shared.wmma.matrix_b[((rc.inner.inner*32) + y.c.inner)]))))
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      for (ax2.inner.outer: int32, 0, 4) {
        for (ax2.inner.inner: int32, 0, 8) {
          for (ax3.inner.inner: int32, 0, 32) {
            if @tir.likely((((ax2.inner.outer*8) + ax2.inner.inner) < 28), dtype=bool) {
              conv_padded.wmma.accumulator.shared[(((ax2.inner.outer*256) + (ax2.inner.inner*32)) + ax3.inner.inner)] = (int32*)conv_padded.wmma.accumulator[(((ax2.inner.outer*256) + (ax2.inner.inner*32)) + ax3.inner.inner)])
            }
          }
        }
      }
    }
    attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
    attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
    attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
    for (w.inner.c.inner.fused.outer.inner: int32, 0, 28) {
      conv[((threadIdx.x*28) + w.inner.c.inner.fused.outer.inner)] = (int32*)conv_padded.wmma.accumulator.shared[((threadIdx.x*28) + w.inner.c.inner.fused.outer.inner)])
    }
    for (i2.inner.i3.inner.fused.outer.outer.outer.outer: int32, 0, 32) {
      attr [IterVar(threadIdx.z, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
      attr [IterVar(threadIdx.y, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1;
      attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
      if @tir.likely((i2.inner.i3.inner.fused.outer.outer.outer.outer < 28), dtype=bool) {
        compute_2[((((blockIdx.z*3584) + (i2.inner.i3.inner.fused.outer.outer.outer.outer*128)) + (blockIdx.x*32)) + threadIdx.x)] = cast(int8, max(min(@tir.round((cast(float32, ((int32*)conv[((i2.inner.i3.inner.fused.outer.outer.outer.outer*32) + threadIdx.x)]) + (int32*)bias_2[((blockIdx.x*32) + threadIdx.x)])))*0.00254f32), dtype=float32), 127f32), 0f32))
      }
    }
  }
}

I want to get rid of these likely guards, such as if @tir.likely((((ax2.outer*8) + ax2.inner) < 28), dtype=bool) and if @tir.likely((((w.c.outer*8) + w.c.inner) < 28), dtype=bool).

Does anyone know how to write a schedule when consecutive stages have different shapes?

This has me quite confused — any help would be appreciated. :pray: :pray: