TVM access beyond array boundary

Hi all,

I am trying to build a SpMM kernel as following:

import tvm
from tvm import te
import scipy
import scipy.sparse

feat_len = 128
num_rows = num_cols = 253
num_threads_per_block = 64
num_cuda_blocks = 127

SrcFeat = te.placeholder((num_cols, feat_len))

adj_scipy_csr = scipy.sparse.random(num_rows, num_cols, density=0.1, format='csr').astype('float32')
adj_indptr = adj_scipy_csr.indptr
adj_indices = adj_scipy_csr.indices
adj_vals = adj_scipy_csr.data
adj_indptr_placeholder = te.placeholder(shape=adj_indptr.shape, \
          dtype=str(adj_indptr.dtype), name='adj_indptr_placeholder')
adj_indices_placeholder = te.placeholder(shape=adj_indices.shape, \
          dtype=str(adj_indices.dtype), name='adj_indices_placeholder')
adj_vals_placeholder = te.placeholder(shape=adj_vals.shape, \
          dtype=str(adj_vals.dtype), name='adj_vals_placeholder')

def msgfunc(row, ff):
        row_start = adj_indptr_placeholder[row]
        row_end = adj_indptr_placeholder[row + 1]
        row_num_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_num_elems), name="elem_idx")
        adj_val = adj_vals_placeholder[row_start + elem_idx]
        feat_val = SrcFeat[adj_indices_placeholder[row_start + elem_idx], ff]
        return te.sum(adj_val * feat_val, axis=elem_idx)
Out = te.compute((num_rows, feat_len), msgfunc, name='Out')
s = te.create_schedule([Out.op])
row_axis = Out.op.axis[0]
feat_axis = Out.op.axis[1]
row_outer, row_inner = s[Out.op].split(row_axis, nparts=num_cuda_blocks)
feat_outer, feat_inner = s[Out.op].split(feat_axis, factor=num_threads_per_block)
s[Out.op].reorder(feat_outer, row_outer, feat_inner, row_inner)
s[Out.op].bind(feat_outer, te.thread_axis("blockIdx.y"))
s[Out.op].bind(row_outer, te.thread_axis("blockIdx.x"))
s[Out.op].bind(feat_inner, te.thread_axis("threadIdx.x"))
out_placeholder = te.placeholder((num_rows, feat_len), dtype=str(adj_vals.dtype), name="out")
f = tvm.build(s, [adj_indptr_placeholder, adj_indices_placeholder, adj_vals_placeholder, SrcFeat, out_placeholder], target='cuda')
print(f.imported_modules[0].get_source())

And here is the generated kernel:

extern "C" __global__ void default_function_kernel0(float* __restrict__ Out, void* __restrict__ adj_indptr_placeholder, void* __restrict__ adj_vals_placeholder, void* __restrict__ placeholder, void* __restrict__ adj_indices_placeholder) {
  for (int row_inner = 0; row_inner < 2; ++row_inner) {
    if (((((int)blockIdx.x) * 2) + row_inner) < 253) {
      Out[(((((((int)blockIdx.x) * 256) + (row_inner * 128)) + (((int)blockIdx.y) * 64)) + ((int)threadIdx.x)))] = 0.000000e+00f;
    }
    for (int elem_idx = 0; elem_idx < (((int*)adj_indptr_placeholder)[((((((int)blockIdx.x) * 2) + row_inner) + 1))] - ((int*)adj_indptr_placeholder)[(((((int)blockIdx.x) * 2) + row_inner))]); ++elem_idx) {
      if (((((int)blockIdx.x) * 2) + row_inner) < 253) {
        Out[(((((((int)blockIdx.x) * 256) + (row_inner * 128)) + (((int)blockIdx.y) * 64)) + ((int)threadIdx.x)))] = (Out[(((((((int)blockIdx.x) * 256) + (row_inner * 128)) + (((int)blockIdx.y) * 64)) + ((int)threadIdx.x)))] + (((float*)adj_vals_placeholder)[((((int*)adj_indptr_placeholder)[(((((int)blockIdx.x) * 2) + row_inner))] + elem_idx))] * ((float*)placeholder)[((((((int*)adj_indices_placeholder)[((((int*)adj_indptr_placeholder)[(((((int)blockIdx.x) * 2) + row_inner))] + elem_idx))] * 128) + (((int)blockIdx.y) * 64)) + ((int)threadIdx.x)))]));
      }
    }
  }
}

TVM succeeds to prevent illegal access of the Out array, but fails to do so with adj_indptr. The length of Out is 253, while that of adj_indptr is 254. The last block has blockIdx.x=126, so in the condition of the elem loop adj_indptr[254] is accessed, which is beyond limit.

It seems like TVM does not know about the length of adj_indptr, why is it so? How should I fix this?

Best Regards

@kira-lin You are right, the two lines of code should be switched:

for (int elem_idx = 0; elem_idx < (((int*)adj_indptr_placeholder)[((((((int)blockIdx.x) * 2) + row_inner) + 1))] - ((int*)adj_indptr_placeholder)[(((((int)blockIdx.x) * 2) + row_inner))]); ++elem_idx) {
    if (((((int)blockIdx.x) * 2) + row_inner) < 253) {

A quick workaround is to disable row partitioning, that is, we set num_cuda_blocks to be num_rows. This schedule gives good performance in practice.