Implement Convolution Using TVM

Hi there,

I am a first-time TVM user, and I am not very familiar with CUDA GPU programming either. I want to write a tensor convolution routine using TVM, and I would like it to be at least somewhat optimized. Here is a specification of my problem.

1. All the tensors are in NCHW form.
2. Vertical and horizontal padding and stride are required.
3. The code will be executed on an NVIDIA GPU with CUDA, cuDNN, cuBLAS, etc. available.
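
To pin down what the specification above means numerically, a naive reference can be handy for checking any optimized version. The following is a minimal sketch in plain Python, not TVM code; the function name conv2d_nchw and the nested-list representation are assumptions made for illustration, and it computes cross-correlation (no filter flip), as deep-learning libraries usually do.

```python
def conv2d_nchw(data, weight, pad, stride):
    """Naive NCHW convolution (cross-correlation) on nested lists.

    data:   [N][C][H][W]      weight: [F][C][KH][KW]
    pad:    (vertical_pad, horizontal_pad)
    stride: (vertical_stride, horizontal_stride)
    """
    n_batch, n_chan = len(data), len(data[0])
    in_h, in_w = len(data[0][0]), len(data[0][0][0])
    n_filt, k_h, k_w = len(weight), len(weight[0][0]), len(weight[0][0][0])
    pad_h, pad_w = pad
    str_h, str_w = stride
    out_h = (in_h - k_h + 2 * pad_h) // str_h + 1
    out_w = (in_w - k_w + 2 * pad_w) // str_w + 1

    out = [[[[0.0] * out_w for _ in range(out_h)]
            for _ in range(n_filt)] for _ in range(n_batch)]
    for n in range(n_batch):
        for f in range(n_filt):
            for oh in range(out_h):
                for ow in range(out_w):
                    acc = 0.0
                    for c in range(n_chan):
                        for kh in range(k_h):
                            for kw in range(k_w):
                                ih = oh * str_h + kh - pad_h
                                iw = ow * str_w + kw - pad_w
                                # Reads outside the image are the zero padding.
                                if 0 <= ih < in_h and 0 <= iw < in_w:
                                    acc += data[n][c][ih][iw] * weight[f][c][kh][kw]
                    out[n][f][oh][ow] = acc
    return out


# A 1x1x3x3 tensor of ones, used as both input and filter.
ones = [[[[1.0] * 3 for _ in range(3)]]]
result = conv2d_nchw(ones, ones, pad=(1, 1), stride=(1, 1))
print(result[0][0])
```

With padding 1 and stride 1 the output keeps the 3x3 spatial size. The centre element sums all nine products, while the corners see only four in-range inputs.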

I am testing a new way of doing convolution, so I really need a baseline implementation to run my experiment against. Built-in convolution routines such as those in cuDNN are too heavily optimized to be good baselines for me. I tried to use CUTLASS, but it is too complicated.

After reading online tutorials such as "How to optimize convolution on GPU" and "Tuning High Performance Convolution on NVIDIA GPUs", I assembled the following code.

import numpy as np
import tvm

""" The computation we are trying to support is:
void* tensorConvolution(void* input, void* filter,
    int vertical_pad, int horizontal_pad,
    int vertical_stride, int horizontal_stride,
    int conv_mode, int conv_groups); """

########################
# Step 1. Describe the computation
########################

# Tensor input and filter dimensions
input_batch = tvm.var("input_batch")
input_channel = tvm.var("input_channel")
input_width = tvm.var("input_width")
input_height = tvm.var("input_height")
filter_num = tvm.var("filter_num")
filter_size = tvm.var("filter_size")

# Additional parameters
# conv_mode is always 1 and is not used. conv_groups is not handled yet.
vertical_pad = tvm.var("vertical_pad")
horizontal_pad = tvm.var("horizontal_pad")
vertical_stride = tvm.var("vertical_stride")
horizontal_stride = tvm.var("horizontal_stride")

# Input data matrix and filter matrix. NCHW
InputDataMatrix = tvm.placeholder((input_batch, input_channel, input_height, input_width),
    name="InputDataMatrix")
FilterMatrix = tvm.placeholder((filter_num, input_channel, filter_size, filter_size),
    name="FilterMatrix")

# Output dimensions
output_width = (input_width - filter_size + 2 * horizontal_pad) // horizontal_stride + 1
output_height = (input_height - filter_size + 2 * vertical_pad) // vertical_stride + 1

# Define padding
InputDataMatrixPad = tvm.compute(
    # The padded tensor is the input extent plus padding on both sides.
    (input_batch, input_channel, input_height + 2 * vertical_pad, input_width + 2 * horizontal_pad),
    lambda nidx, cidx, hidx, widx: tvm.if_then_else(
        tvm.all(hidx - vertical_pad >= 0, hidx - vertical_pad < input_height,
            widx - horizontal_pad >= 0, widx - horizontal_pad < input_width),
        InputDataMatrix[nidx, cidx, hidx - vertical_pad, widx - horizontal_pad], tvm.const(0., "float32")
    ), name="InputDataMatrixPad"
)

# Create reduction variables
rc = tvm.reduce_axis((0, input_channel), name='rc')
rh = tvm.reduce_axis((0, filter_size), name='rh')
rw = tvm.reduce_axis((0, filter_size), name='rw')

# Compute the convolution
OutputMatrix = tvm.compute(
    (input_batch, filter_num, output_height, output_width),
    lambda nidx, fidx, hidx, widx: tvm.sum(
        InputDataMatrixPad[nidx, rc, hidx * vertical_stride + rh, widx * horizontal_stride + rw] *
            FilterMatrix[fidx, rc, rh, rw],
        axis=[rc, rh, rw]
    ), name="OutputMatrix"
)

########################
# Step 2. Design the memory hierarchy
########################

# Schedule. Do padding inline.
schedule = tvm.create_schedule(OutputMatrix.op)
schedule[InputDataMatrixPad].compute_inline()

# Memory read and write
InputDataMatrixPadInShared = schedule.cache_read(InputDataMatrixPad, "shared", [OutputMatrix])
FilterMatrixInShared = schedule.cache_read(FilterMatrix, "shared", [OutputMatrix])
InputDataMatrixPadInLocal = schedule.cache_read(InputDataMatrixPadInShared, "local",
    [OutputMatrix])
FilterMatrixInLocal = schedule.cache_read(FilterMatrixInShared, "local", [OutputMatrix])
OutputMatrixInLocal = schedule.cache_write(OutputMatrix, "local")

########################
# Step 3. Blocking
########################

# For each input image x num filter x image channel, a block is created.
nidx, fidx, hidx, widx = schedule[OutputMatrix].op.axis
schedule[OutputMatrix].bind(nidx, tvm.thread_axis("blockIdx.z"))
schedule[OutputMatrix].bind(fidx, tvm.thread_axis("blockIdx.y"))
rcidx, rhidx, rwidx = schedule[OutputMatrix].op.reduce_axis
schedule[OutputMatrix].bind(rcidx, tvm.thread_axis("blockIdx.x"))

# Then, we are left with 2D convolution
# Input[height * width] conv. Filter[size * size] = Output[hidx * widx]
# We split the workload by a factor of 32
hwidxSplitFactor = 32
hwidxFused = schedule[OutputMatrix].fuse(hidx, widx)
hwidxFusedSplit, _ = schedule[OutputMatrix].split(hwidxFused, factor=hwidxSplitFactor)
schedule[OutputMatrix].bind(hwidxFusedSplit, tvm.thread_axis("threadIdx.x"))

########################
# Step 4. Memory Fetching
########################

schedule[OutputMatrixInLocal].compute_at(schedule[OutputMatrix], hwidxFusedSplit)
rcidx, rhidx, rwidx = schedule[OutputMatrixInLocal].op.reduce_axis
rhrwFused = schedule[OutputMatrixInLocal].fuse(rhidx, rwidx)
schedule[InputDataMatrixPadInShared].compute_at(schedule[OutputMatrixInLocal], rhrwFused)
schedule[InputDataMatrixPadInLocal].compute_at(schedule[OutputMatrixInLocal], rhrwFused)
schedule[FilterMatrixInShared].compute_at(schedule[OutputMatrixInLocal], rhrwFused)
schedule[FilterMatrixInLocal].compute_at(schedule[OutputMatrixInLocal], rhrwFused)

########################
# Step 5. Testing
########################

func = tvm.build(schedule, [InputDataMatrix, FilterMatrix, vertical_pad, horizontal_pad,
    vertical_stride, horizontal_stride, OutputMatrix], "cuda")
ctx = tvm.gpu(0)

input_batch_v = 2
input_channel_v = 3
input_width_v = 3
input_height_v = 3
filter_num_v = 3
filter_size_v = 3
vertical_pad_v = 1
horizontal_pad_v = 1
vertical_stride_v = 1
horizontal_stride_v = 1

# Input
a_np = np.random.uniform(size=(input_batch_v, input_channel_v, input_height_v,
    input_width_v)).astype(InputDataMatrix.dtype)
w_np = np.random.uniform(size=(filter_num_v, input_channel_v, filter_size_v,
    filter_size_v)).astype(FilterMatrix.dtype)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)

# Output
output_width_v = (input_width_v - filter_size_v + 2 * horizontal_pad_v) // horizontal_stride_v + 1
output_height_v = (input_height_v - filter_size_v + 2 * vertical_pad_v) // vertical_stride_v + 1
b = tvm.nd.array(np.zeros((input_batch_v, filter_num_v, output_height_v, output_width_v),
    dtype=OutputMatrix.dtype), ctx)

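# Arguments must follow the order given to tvm.build: inputs, pads, strides, then output.
func(a, w, vertical_pad_v, horizontal_pad_v, vertical_stride_v, horizontal_stride_v, b)
evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
print('Convolution: %f ms' % (evaluator(a, w, vertical_pad_v, horizontal_pad_v,
    vertical_stride_v, horizontal_stride_v, b).mean * 1e3))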

However, it doesn’t quite work. The first error I am getting is:

Traceback (most recent call last):
  File "matrixConvolution.py", line 99, in <module>
    rcidx, rhidx, rwidx = schedule[OutputMatrix].op.reduce_axis
ValueError: not enough values to unpack (expected 3, got 0)

Please help me with the implementation. Also, if you know a better tutorial, either in TVM or CUDA, please share it with me. Thank you.

If all you need is to get conv2d working on TVM, you can directly use the conv2d op that has been defined and optimized. For example, the following code takes data (NCHW) and weight (OIHW), executes a conv2d with stride (2, 2), and produces output (NCHW). Since the target is set to cuda, it will automatically use the well-defined schedule for CUDA on GPU. If you would like to use cuBLAS/cuDNN, you can simply change the target to cuda -libs=cublas,cudnn.

import numpy as np
import tvm  # needed below for tvm.target, tvm.gpu and tvm.nd
from tvm import relay
from tvm.contrib import graph_runtime

data = relay.var('data', shape=(1, 3, 224, 224))
weight = relay.var('weight', shape=(32, 3, 3, 3))
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
fun = relay.Function([data, weight], out)
mod = relay.Module()
mod['main'] = fun
target = tvm.target.cuda()
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build_module.build(mod, target, params={})

ctx = tvm.gpu()
i_data = np.random.uniform(0, 1, size=(1, 3, 224, 224)).astype('float32')
i_weight = np.random.uniform(0, 1, size=(32, 3, 3, 3)).astype('float32')

module = graph_runtime.create(graph, lib, ctx)
module.set_input('data', i_data)
module.set_input('weight', i_weight)
module.run()
out = module.get_output(0, tvm.nd.empty((1, 32, 112, 112))).asnumpy()
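
The output shape (1, 32, 112, 112) requested above follows from the same output-size formula used in the original post. As a quick sanity check, here is a small sketch; the helper name conv_output_size is hypothetical and not part of TVM.

```python
def conv_output_size(in_size, kernel, pad, stride):
    """Spatial output extent of a convolution along one dimension."""
    return (in_size - kernel + 2 * pad) // stride + 1


# 224x224 input, 3x3 kernel, padding 1, stride 2, as in the Relay example.
print(conv_output_size(224, 3, 1, 2))

# 3x3 input, 3x3 kernel, padding 1, stride 1, as in the test values of the original post.
print(conv_output_size(3, 3, 1, 1))
```

The first call gives 112 and the second gives 3, so the test tensors in the original post keep their 3x3 spatial size.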

I am testing a new way of doing convolution, so I really need to have my own implementation of it. TVM Tensor Expression hides a lot of complexity, which is great. I just can’t understand why I am getting that error. Why is the reduce_axis member empty?
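
A likely explanation, assuming the usual cache_write semantics: after schedule.cache_write(OutputMatrix, "local"), the reduction is carried by the returned stage (OutputMatrixInLocal in the code above), and the original OutputMatrix stage becomes a plain copy out of that cache, so its reduce_axis is empty. The reduction axes would then come from schedule[OutputMatrixInLocal].op.reduce_axis, which is what the later "Memory Fetching" step already does. Below is a minimal sketch using the legacy pre-0.7 API from this thread (newer releases moved these functions under tvm.te). The function name reduce_axis_counts and the toy row-sum are hypothetical.

```python
def reduce_axis_counts():
    """Count reduce axes before and after cache_write on a toy row-sum.

    Returns (before, on_output_stage_after, on_cache_stage_after).
    """
    import tvm  # assumption: legacy (pre-0.7) TVM API, as used in this thread

    n = tvm.var("n")
    k = tvm.reduce_axis((0, 8), name="k")
    A = tvm.placeholder((n, 8), name="A")
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

    s = tvm.create_schedule(B.op)
    before = len(s[B].op.reduce_axis)

    # cache_write moves the reduction into the new "local" stage BL;
    # the stage for B is rewritten into a copy from BL.
    BL = s.cache_write(B, "local")
    return before, len(s[B].op.reduce_axis), len(s[BL].op.reduce_axis)
```

If this reading is right, the middle count should be 0 after cache_write, which matches the "expected 3, got 0" in the traceback. Note that binding a reduction axis directly to blockIdx.x is a separate issue and may need its own fix.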

CUTLASS has open-sourced an NHWC 2D convolution. It is not difficult to adapt it to NCHW. Feel free to ask questions on the CUTLASS GitHub.
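
For readers unsure what the NHWC/NCHW difference amounts to, it is only the order in which the four indices are flattened into memory. The sketch below reorders a flat NCHW buffer into NHWC in plain Python. It is not CUTLASS code, and the function name nchw_to_nhwc is hypothetical. Transposing the data like this is one simple option, while adapting the kernel's indexing is another.

```python
def nchw_to_nhwc(flat, n, c, h, w):
    """Reorder a flat, row-major NCHW buffer into a flat NHWC buffer."""
    out = [0.0] * (n * c * h * w)
    for ni in range(n):
        for ci in range(c):
            for hi in range(h):
                for wi in range(w):
                    src = ((ni * c + ci) * h + hi) * w + wi  # offset in NCHW
                    dst = ((ni * h + hi) * w + wi) * c + ci  # offset in NHWC
                    out[dst] = flat[src]
    return out


# One image, two channels, 1x2 pixels: channel 0 is [0, 1], channel 1 is [2, 3].
print(nchw_to_nhwc([0.0, 1.0, 2.0, 3.0], n=1, c=2, h=1, w=2))
```

In NHWC the channel values of each pixel sit next to each other, so the example prints [0.0, 2.0, 1.0, 3.0].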