Any documentation on Winograd convolution scheduling on CUDA?

TVM version 0.6

I know TVM supports Winograd convolution, but I cannot find any documentation on it.
There are Python functions such as topi.cuda.schedule_conv2d_nchw_cuda, but they are not documented.
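
For reference, this is roughly how I understand those TOPI entry points are meant to be used (a minimal sketch with made-up shapes; I am not sure whether the generic dispatcher actually picks the Winograd template here without a tuned config, so treat this as an assumption):

import tvm
import topi

# made-up shapes, just to illustrate the call sequence
data = tvm.placeholder((1, 64, 56, 56), name='data')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel')

with tvm.target.create('cuda'):
    conv = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                          layout='NCHW', out_dtype='float32')
    # dispatches to the CUDA schedule for the current target;
    # whether the direct or the winograd template is used depends on the config
    s = topi.generic.schedule_conv2d_nchw([conv])
    print(tvm.lower(s, [data, kernel, conv], simple_mode=True))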

I finally managed to put together a piece of code that seems to do it.
Auto-tuning ran fine, but when I try to build and run the tuned schedule,
tvm.build fails with:

Traceback (most recent call last):
  File "manual_conv2d_winograd.py", line 271, in <module>
    func = tvm.build(s, arg_bufs, target='cuda')
  File "/workspace/tvm/python/tvm/build_module.py", line 574, in build
    binds=binds)
  File "/workspace/tvm/python/tvm/build_module.py", line 417, in lower
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
  File "tvm/_ffi/_cython/./function.pxi", line 310, in tvm._ffi._cy3.core.FunctionBase.__call__
  File "tvm/_ffi/_cython/./function.pxi", line 255, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 171, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call<tvm::LoweredFunc, 5, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)+0x2b) [0x7f5c5e66c058]
  [bt] (7) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 5, 0, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)+0x5a) [0x7f5c5e66d921]
  [bt] (6) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 4, 1, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<tvm::runtime::TVMArgValue>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&)+0x73) [0x7f5c5e66fab2]
  [bt] (5) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 3, 2, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0x8a) [0x7f5c5e671172]
  [bt] (4) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 2, 3, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xa4) [0x7f5c5e6720fb]
  [bt] (3) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 1, 4, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0xba) [0x7f5c5e6726db]
  [bt] (2) /workspace/tvm/build/libtvm.so(void tvm::runtime::detail::unpack_call_dispatcher<tvm::LoweredFunc, 0, 5, tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::LoweredFunc (* const&)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&)+0x11c) [0x7f5c5e672a39]
  [bt] (1) /workspace/tvm/build/libtvm.so(tvm::ir::MakeAPI(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)+0x2007) [0x7f5c5ea3f736]
  [bt] (0) /workspace/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x34) [0x7f5c5e5e9656]
  File "/workspace/tvm/src/pass/make_api.cc", line 187
TVMError: Not all Vars are passed in api_args:  'threadIdx.x'  does not appear in api_args

Here is the Python code:

import logging
import sys
import numpy as np
import os

import tvm
import topi
from topi.testing import conv2d_nchw_python
import args_parsing
import tvm.contrib.nnpack

from tvm import autotvm
import json
import yaml
import topi.cuda.conv2d_winograd


def schedule_winograd_cuda(cfg, s, output, pre_computed):
    """Schedule winograd template"""
    # get stages
    inverse = s[output].op.input_tensors[0]
    bgemm, A = s[inverse].op.input_tensors
    kernel_pack, data_pack = s[bgemm].op.input_tensors
    input_tile, B = s[data_pack].op.input_tensors
    pad_data = s[input_tile].op.input_tensors[0]

    # data transform
    s[B].compute_inline()

    data_l = s.cache_write(data_pack, 'local')
    eps, nu, c, p = s[data_l].op.axis
    r_a, r_b = s[data_l].op.reduce_axis
    for axis in [eps, nu, r_a, r_b]:
        s[data_l].unroll(axis)

    eps, nu, c, p = s[data_pack].op.axis
    p, pi = s[data_pack].split(p, 1)
    fused = s[data_pack].fuse(c, p)
    bb, tt = s[data_pack].split(fused, 128)
    s[data_pack].reorder(bb, tt, pi, eps, nu)
    s[data_pack].bind(bb, tvm.thread_axis("blockIdx.x"))
    s[data_pack].bind(tt, tvm.thread_axis("threadIdx.x"))

    s[data_l].compute_at(s[data_pack], pi)
    s[input_tile].compute_at(s[data_pack], pi)
    s[pad_data].compute_inline()

    # transform kernel
    if not pre_computed:
        kernel, G = s[kernel_pack].op.input_tensors
        eps, nu, ci, co = s[kernel_pack].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # skip this part during tuning to make records accurate
            # this part will be pre-computed during NNVM's pre-compute optimization pass
            s[G].pragma(s[G].op.axis[0], 'debug_skip_region')
            s[kernel_pack].pragma(eps, 'debug_skip_region')
        else:
            s[G].compute_inline()
            r_a, r_b = s[kernel_pack].op.reduce_axis
            for axis in [eps, nu, r_a, r_b]:
                s[kernel_pack].unroll(axis)

            fused = s[kernel_pack].fuse(ci, co)
            bb, tt = s[kernel_pack].split(fused, 128)
            s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b)
            s[kernel_pack].bind(bb, tvm.thread_axis("blockIdx.x"))
            s[kernel_pack].bind(tt, tvm.thread_axis("threadIdx.x"))
    else:
        kernel = kernel_pack

    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()

    ##### space definition begin #####
    b1, b2, y, x = s[bgemm].op.axis
    rc = s[bgemm].op.reduce_axis[0]
    alpha = topi.util.get_const_int(b1.dom.extent)

    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
                     filter=lambda x: x.size[-3:] == [1, 1, 1])
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=2)
    cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
    target = tvm.target.current_target()
    if target.target_name in ['nvptx', 'rocm']:
        cfg.define_knob("unroll_explicit", [1])
    else:
        cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####
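    # tile_b / tile_y / tile_x above are 4-way splits that get bound to
    # blockIdx / vthread / threadIdx / inner below, and tile_rc splits the
    # reduction axis of the batched GEMM.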

    # batch gemm
    C = bgemm
    A0, B0 = kernel_pack, data_pack

    OL = s.cache_write(C, 'local')
    AA = s.cache_read(A0, 'shared', [OL])
    BB = s.cache_read(B0, 'shared', [OL])

    b = s[bgemm].fuse(b1, b2)

    # tile and bind spatial axes
    bgemm_scope, b = s[bgemm].split(b, nparts=1)
    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
    s[C].bind(bz, tvm.thread_axis("blockIdx.z"))
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(vz, tvm.thread_axis("vthread"))
    s[C].bind(vy, tvm.thread_axis("vthread"))
    s[C].bind(vx, tvm.thread_axis("vthread"))
    s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)

    # tile reduction axes
    s[OL].compute_at(s[C], tx)
    b1, b2, y, x = s[OL].op.axis
    b = s[OL].fuse(b1, b2)
    rc, = s[OL].op.reduce_axis
    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
    s[OL].reorder(rco, rci, b, y, x)

    s[AA].compute_at(s[OL], rco)
    s[BB].compute_at(s[OL], rco)

    # cooperative fetching
    for load in [AA, BB]:
        fused = s[load].fuse(*list(s[load].op.axis))
        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    # schedule inverse, output and fusion
    if output.op in s.outputs:
        OL = None
    else:
        OL = output
        s[OL].set_scope('local')
        output = s.outputs[0]

    m = alpha - 3 + 1
    n, co, h, w = s[output].op.axis
    ho, wo, hi, wi = s[output].tile(h, w, m, m)
    inverse_scope, n = s[output].split(n, nparts=1)

    fused = s[output].fuse(n, co, ho, wo)
    bb, tt = s[output].split(fused, 128)

    s[output].bind(bb, tvm.thread_axis("blockIdx.x"))
    s[output].bind(tt, tvm.thread_axis("threadIdx.x"))

    if OL is not None:
        s[OL].compute_at(s[output], tt)

    s[A].compute_inline()
    co, p, vh, vw = s[inverse].op.axis
    r_a, r_b = s[inverse].op.reduce_axis
    for axis in [vh, vw, r_a, r_b]:
        s[inverse].unroll(axis)
    s[inverse].compute_at(s[output], tt)

    return s

@tvm.autotvm.task.register('conv2d_winograd_my')
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    # conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
    conv = topi.nn.conv2d(data, kernel, stride, padding, dilation=1, layout='NCHW', out_dtype='float32')
    s = tvm.create_schedule([conv.op])
    print("-----------------------original IR")
    print(tvm.lower(s, [conv, data, kernel], simple_mode=True))
    print("-------------------------------------")
    cfg = autotvm.DispatchContext.current.query(tvm.target.cuda(), None)
    cfg.template_key = 'winograd'
    # cfg = autotvm.get_config()
    s = schedule_winograd_cuda(cfg, s, conv, pre_computed=False)

    # print("----------------------------- raw graph")
    # print(tvm.lower(s, [data, kernel, conv], simple_mode=True))

    # return s, [data, kernel, conv]
    return s, [data, kernel, conv]


# logging config (for printing tuning log to screen)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

with open(sys.argv[1], "r") as f:
    config = yaml.load(f)

for key in config.keys():
    config[key] = list(map(int, config[key].split(",")))

N, CI, W, H, CO, stride, pad, KW, KH = config['hyper_parameters']
tile_f, tile_x, tile_y = config['tile_f'], config['tile_x'], config['tile_y']
tile_rc, tile_rx, tile_ry = config['tile_rc'], config['tile_rx'], config['tile_ry']
params = {
    "i": [
        "cuda", "conv2d_no_batching", [N, W, H, CO, CI, KW, KH, [stride, stride], [pad, pad]], {},
        ["conv2d_no_batching", N, W, H, CO, CI, KW, KH, [stride, stride], [pad, pad]],
        {
            "i": 462435094,
            "t": "",
            "c": None,
            "e": [
                ["tile_f", "sp", tile_f],
                ["tile_y", "sp", tile_y],
                ["tile_x", "sp", tile_x],
                ["tile_rc", "sp", tile_rc],
                ["tile_ry", "sp", tile_ry],
                ["tile_rx", "sp", tile_rx]
            ]
        }
    ],
    "r": [[0.001660518793814433], 0, 3.7641751766204834, 1564989891.256183],
    "v": 0.1
}

strides = (stride, stride)
padding = (pad, pad)
print("N = %d\nCI = %d\nCO = %d\nstride = %d\npad = %d" % (N, CI, CO, stride, pad))
print("KHxKW = %dx%d" % (KH, KW))
print("HxW = %dx%d" % (H, W))
print("padded Input = %dx%d" % (H + pad*2, W + pad*2))
print("HOxWO = %dx%d" % ((H - KH + 1 + pad*2) / stride, (W - KW + 1 + pad*2) / stride))
# N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 1, 1, 1, 1, (1, 1), (0, 0)


data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
task = autotvm.task.create(func_name='conv2d_winograd_my',
                           args=(N, H, W, CO, CI, KH, KW, stride, pad),
                           target='cuda', target_host='llvm', template_key='winograd')

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=5, min_repeat_ms=100, timeout=4)
)

# if os.environ.get("AUTOTUNE") == "1":
tuner = autotvm.tuner.XGBTuner(task)
# tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[autotvm.callback.log_to_file('conv2d_winograd.log')])

# log_file = "conv2d.log.tmp"
# with open("conv2d.log.tmp", "w") as f:
#     f.write(json.dumps(params))

config_space = autotvm.ConfigSpace()
config_space.template_key = "winograd"

# with autotvm.apply_history_best('conv2d_winograd.log'):
with tvm.autotvm.task.ApplyConfig(config_space):
    with tvm.target.create("cuda"):
        # s, arg_bufs = task.instantiate(config_space)
        # s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
        s, arg_bufs = task.extra_data
        print(tvm.lower(s, arg_bufs, simple_mode=True))
        func = tvm.build(s, arg_bufs, target='cuda')
        print(func.imported_modules[0].get_source())

# check correctness
a_np = np.random.normal(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.normal(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)

ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)

#tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-1)

# Evaluate running time. Here we choose a large repeat number (400) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
print('Time cost of this operator: %fms' % (evaluator(a_tvm, w_tvm, c_tvm).mean*1e3))


# config.yaml
hyper_parameters: 1, 32, 256,256, 32, 1,1, 3,3
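
The script also reads tile_f, tile_x, tile_y, tile_rc, tile_rx and tile_ry from the same file (the values below are made-up placeholders, only to show the comma-separated format the parser expects):

tile_f: -1, 2, 8, 2
tile_y: -1, 2, 8, 2
tile_x: -1, 2, 8, 2
tile_rc: -1, 8
tile_ry: -1, 3
tile_rx: -1, 3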

Run with: python3 conv2d_winograd.py config.yaml

@merrymercy @cbalint13