How to avoid unnecessary memory allocation when `cache_read` is bound to blockIdx

TL;DR. The data could actually fit into shared memory if each block held 1/k of it, where k is the number of blocks. However, a cache_read followed by split + bind to blockIdx results in k times more memory being allocated than necessary.

Background. In CUDA, blocks often cooperatively load slices of data into their own shared memory; this is especially useful when their workloads are independent of each other. So, in TVM, we may want to load shards (slices) of an array into on-chip memory instead of the whole array. This is (imho) a critical step for constructing a persistent RNN (link).

Why not load it all. Shared memory is scarce (96 KB per SM on V100). A square FP32 weight matrix of 256 * 256 already takes 256 * 256 * 4 B = 256 KB, which results in a runtime error.

Why not compute_at. This causes an unacceptable amount of redundant memory traffic. For a single GEMM, it is good practice to place the cache_read right where the data is actually needed using compute_at. However, in a persistent RNN, the GEMM is invoked on literally every time step. If we use compute_at to move the load into the GEMM computation, the same data is read from global memory over and over again, wasting memory bandwidth.
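To make the bandwidth concern concrete, here is a minimal sketch (all names and shapes here, `W`, `X`, `Y`, `T`, `m`, `n`, are illustrative, and I am assuming the old pre-`te` tvm API); a matvec stands in for the per-time-step GEMM, and attaching the cache_read at the time axis with compute_at puts the shared-memory fill inside the time loop, so the same weights are re-read from global memory on every step:

```python
import tvm

# Illustrative sizes: T time steps, an m x n weight matrix (names are mine).
T, m, n = 128, 64, 64
W = tvm.placeholder((m, n), name="W")
X = tvm.placeholder((T, n), name="X")

# A stand-in for "GEMM invoked on every time step": one matvec per t.
r = tvm.reduce_axis((0, n), name="r")
Y = tvm.compute((T, m), lambda t, i: tvm.sum(W[i, r] * X[t, r], axis=r), name="Y")

s = tvm.create_schedule(Y.op)
WS = s.cache_read(W, "shared", [Y])
t, i = s[Y].op.axis
s[Y].bind(i, tvm.thread_axis("threadIdx.x"))

# Attaching the shared load at the time axis puts the W.shared fill loop
# inside `for t`, i.e. the weights are re-fetched on every time step.
s[WS].compute_at(s[Y], t)
print(tvm.lower(s, [W, X, Y], simple_mode=True))
```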

Example. Let’s assume m, n, k are known constants, so that TVM can allocate static memory. Let w be an (m * n) array, and let there be k blocks on the GPU. WLOG, say m is divisible by k. In block i, we want to load only w[m / k * i : m / k * (i + 1), :] into that SM’s on-chip memory.

### Python code
wS = s.cache_read(w, "shared", readers=readers)  # readers: the stages that consume w
ax_m, _ = s[wS].op.axis                          # axes of the (m, n) cache stage
bx, _ = s[wS].split(ax_m, nparts=num_sm)         # num_sm = k, the number of blocks
s[wS].bind(bx, tvm.thread_axis("blockIdx.x"))

### However, it translates to
allocate w.shared[float32 * m * n]

### Instead, what we actually need is:
allocate w.shared[float32 * (m / k) * n]  # k is the number of SMs

I am relatively new to the scheduling sub-language in TVM. Any suggestions? Thanks in advance!

The information provided here is too limited for me to give any concrete suggestions. Can you provide more context, at least from the shared-memory allocation down to the computation body?

@were Thank you for the quick response.

Let’s take a vanilla RNN with a tanh activation as the example; the pseudocode looks like:

Input: x[seq_len, batch_size, input_dim]
Output: h[seq_len, batch_size, hidden_dim]
Weights: w_i2h[hidden_dim, input_dim]
         w_h2h[hidden_dim, hidden_dim]
         (weights have been transposed, biases are omitted)

# In this example, the number of blocks is equal to the number of SMs
bx = blockIdx.x  # the index of the SM we are currently at 
st = hidden_dim / num_sm * bx
ed = hidden_dim / num_sm * (bx + 1)  # say hidden_dim is divisible by num_sm

# Persist the weights into shared memory
w_i2h.shared <- w_i2h[st:ed, :]
w_h2h.shared <- w_h2h[st:ed, :]

# initialize hidden state
h[0, : , st:ed] = 0
some_global_barrier();

# the main loop
for scan.idx in range(1, seq_len):
    # computation inside the RNN cell, each SM could do this independently
    s_i2h = linear(w_i2h.shared, x[scan.idx - 1,:,:])
    s_h2h = linear(w_h2h.shared, h[scan.idx - 1,:,:])
    next_h = tanh(s_i2h + s_h2h)
    # SMs cooperatively calculate h[scan.idx,:,:]
    h[scan.idx, : , st:ed] = next_h
    some_global_barrier();

If we do compute_at on any of the axes inside this loop, won’t it cause too many unnecessary reads?

Therefore, I do the compute_at outside the loop. However, this causes too much memory allocation, because TVM automatically creates a shared array of the same size as w_i2h and w_h2h, as mentioned in the first post of this thread.

Can you try to do explicit packing?

cache_read can only fetch a contiguous range, but cache_read is actually not necessary here. We can create a new stage with a custom tvm.compute and tag its memory scope as shared memory. In that compute, we can fetch only the data that is actually needed.
This requires rewriting your tvm.compute; I don’t know whether that is easy in your case.
I used this trick for dilation (https://github.com/dmlc/tvm/issues/1887#issuecomment-434568924).
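Something along these lines might work (a toy sketch only: the shapes, `num_sm = 64`, the names `w_pack`, `y`, `rows`, and the matvec standing in for the i2h GEMM are all made up, and it assumes the old pre-`te` tvm API). The packing stage makes the block index explicit in its layout, is tagged as shared, and is attached at the consumer’s block axis:

```python
import tvm

# Assumed/illustrative sizes: 64 blocks, each persisting a 16 x 512 shard.
hidden_dim, input_dim, num_sm = 1024, 512, 64
rows = hidden_dim // num_sm

w_i2h = tvm.placeholder((hidden_dim, input_dim), name="w_i2h")
x = tvm.placeholder((input_dim,), name="x")

# Explicit packing stage: the outermost axis is the block index, so block b
# only ever reads its own rows [b * rows, (b + 1) * rows).
w_pack = tvm.compute(
    (num_sm, rows, input_dim),
    lambda b, i, j: w_i2h[b * rows + i, j],
    name="w_pack")

# Toy consumer standing in for the i2h GEMM of one time step.
r = tvm.reduce_axis((0, input_dim), name="r")
y = tvm.compute(
    (num_sm, rows),
    lambda b, i: tvm.sum(w_pack[b, i, r] * x[r], axis=r),
    name="y")

s = tvm.create_schedule(y.op)
s[w_pack].set_scope("shared")

bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
yb, yi = s[y].op.axis
s[y].bind(yb, bx)
s[y].bind(yi, tx)

# Attach the packing at the block axis and let the threads of the block
# fill the shard cooperatively.
s[w_pack].compute_at(s[y], yb)
pi = s[w_pack].op.axis[1]
s[w_pack].bind(pi, tx)

# The lowered IR should now allocate w_pack.shared with extent
# rows * input_dim per block instead of hidden_dim * input_dim.
print(tvm.lower(s, [w_i2h, x, y], simple_mode=True))
```

How this attach point interacts with the scan’s time axis in the real persistent RNN is a separate question that this sketch does not answer.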

Thank you for the suggestion! Will try it out!

What’s the final solution for this question?
@junrushao