How to avoid allocating the wrong number of registers in CUDA scheduling in this example?

Hello, I have found that sometimes more registers are allocated than expected in CUDA code generation. It usually happens when data is cached in registers and then read into shared memory chunk by chunk.

Here’s a simple example resembling an actual problem I recently encountered. I aim to read the whole A tensor and cache its data in registers, then compute B in shared memory block by block from this data, and finally write the output to C.

import tvm

def schedule(A, B, C):
    s = tvm.create_schedule(C.op)

    # Cache A in registers ("local") for the producer of B; keep B in shared memory.
    AA = s.cache_read(A, "local", [B])
    s[B].set_scope("shared")
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, 32), "threadIdx.x")

    # Split C's 128-element axis: ooc (1 block) x ioc (2) x oic (2) x iic (32 threads).
    oc, ic = s[C].split(s[C].op.axis[0], factor=64)
    ooc, ioc = s[C].split(oc, factor=2)
    oic, iic = s[C].split(ic, factor=32)
    s[C].bind(ooc, block_x)
    s[C].bind(iic, thread_x)

    # Compute B chunk by chunk inside C's ioc loop.
    s[B].compute_at(s[C], ioc)
    ob, ib = s[B].split(s[B].op.axis[0], factor=32)
    s[B].bind(ib, thread_x)

    # Load all of A into registers once, outside the ioc loop.
    s[AA].compute_root()  # note: immediately overridden by the compute_at below
    s[AA].compute_at(s[C], ooc)
    oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
    s[AA].bind(iaa, thread_x)

    return s

def test():
    A = tvm.placeholder((128,), name="A")
    B = tvm.compute((128,), lambda i: A[i] + 1, name="B")
    C = tvm.compute((128,), lambda i: B[i] + 2, name="C")

    device = "cuda"
    ctx = tvm.context(device, 0)
    with tvm.target.create(device):
        s = schedule(A, B, C)

    print(tvm.lower(s, [A, B, C], simple_mode=True))
    func = tvm.build(s, [A, B, C], device, name="test")
    print(func.imported_modules[0].get_source())

if __name__ == "__main__":
    test()

However, in the resulting lowered IR:

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [A.local] storage_scope = "local"
  allocate A.local[float32 * 97]
  produce A.local {
    for (ax0.outer, 0, 4) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      if (likely((((ax0.outer*32) + threadIdx.x) < 97))) {
        if (likely((((ax0.outer*32) + (threadIdx.x*2)) < 128))) {
          A.local[((ax0.outer*32) + threadIdx.x)] = A[((ax0.outer*32) + (threadIdx.x*2))]
        }
      }
    }
  }
  for (i.outer.inner, 0, 2) {
    produce B {
      for (i.outer, 0, 2) {
        // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
        B[(((i.outer.inner*64) + (i.outer*32)) + threadIdx.x)] = (A.local[((i.outer.inner*64) + (i.outer*32))] + 1f)
      }
    }
    for (i.inner.outer, 0, 2) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      C[(((i.outer.inner*64) + (i.inner.outer*32)) + threadIdx.x)] = (B[(((i.outer.inner*64) + (i.inner.outer*32)) + threadIdx.x)] + 2f)
    }
  }
}

I find that the compiler allocates 97 registers for A.local, which is not what I want. I expect it to allocate only 4 registers per thread, which is enough to carry out the computation correctly under this schedule: each thread reads 4 elements of A, stores them in registers, computes 2 elements each of B and C per iteration, and repeats that twice.
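To make the intended layout concrete, here is a small plain-Python check (illustrative only, no TVM involved) that with 32 threads each thread owns exactly 4 elements of A and all 128 elements are covered:

# Illustrative check of the intended register layout: thread t should hold
# A[t + 32*k] for k = 0..3, i.e. 4 registers per thread, covering all of A.
N, THREADS = 128, 32
owned = {t: [t + THREADS * k for k in range(N // THREADS)] for t in range(THREADS)}
assert all(len(v) == 4 for v in owned.values())  # 4 registers per thread
assert sorted(i for v in owned.values() for i in v) == list(range(N))  # full coverage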

I think the number 97 arises because the inferred region for A.local spans A[0] through A[96]: if thread 0 keeps A[0], A[32], A[64] and A[96] in one buffer indexed by the global offset, the extent is 96 - 0 + 1 = 97 (and similarly for the other threads).

I have encountered a similar problem before and solved it by binding the axis ax0.outer to vthread_x. But this time I need to keep the for loop over ax0.outer, so vthread is out of the question; the sketch below shows what that workaround looked like.
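For reference, a hedged sketch of that earlier workaround, written against the schedule() above (the vthread extent of 4 is illustrative):

# Sketch of the previous workaround: bind the outer cache-read axis to a
# virtual thread instead of leaving it as the ax0.outer for loop.
vthread_x = tvm.thread_axis((0, 4), "vthread", name="vx")
s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(oaa, vthread_x)  # removes the ax0.outer loop, which is
s[AA].bind(iaa, thread_x)   # exactly the loop I need to keep this time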

Can anyone help?


Does anyone have any thoughts on this problem? I was wondering whether this could be because TVM doesn’t support scheduling with caching in registers, or whether it does and I just didn’t do it correctly. Thanks!

@masahi Can you take a look at this? Thanks!

I just added a few lines for verification. Your code produces wrong results. I will take a closer look.

from __future__ import absolute_import, print_function
import tvm
import numpy as np

def schedule(A, B, C):
    s = tvm.create_schedule(C.op)
    
    AA = s.cache_read(A, "local", [B])
    s[B].set_scope("shared")
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, 32), "threadIdx.x")

    oc, ic = s[C].split(s[C].op.axis[0], factor=64)
    ooc, ioc = s[C].split(oc, factor=2)
    oic, iic = s[C].split(ic, factor=32)
    s[C].bind(ooc, block_x)
    s[C].bind(iic, thread_x)

    s[B].compute_at(s[C], ioc)
    ob, ib = s[B].split(s[B].op.axis[0], factor=32)
    s[B].bind(ib, thread_x)

    s[AA].compute_root()
    s[AA].compute_at(s[C], ooc)
    oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
    s[AA].bind(iaa, thread_x)

    return s

def test():
    A = tvm.placeholder((128,), name="A")
    B = tvm.compute((128,), lambda i: A[i] + 1, name="B")
    C = tvm.compute((128,), lambda i: B[i] + 2, name="C")

    device = "cuda"
    ctx = tvm.context(device, 0)
    with tvm.target.create(device):
        s = schedule(A, B, C)

    print(tvm.lower(s, [A, B, C], simple_mode=True))
    func = tvm.build(s, [A, B, C], device, name="test")
    print(func.imported_modules[0].get_source())

    from tvm.contrib import tedd
    tedd.viz_dataflow_graph(s, False, '/tmp/dfg.dot')
    tedd.viz_schedule_tree(s, False, '/tmp/scheduletree.dot')
    tedd.viz_itervar_relationship_graph(s, False, '/tmp/itervar.dot')

    a = tvm.nd.array(np.random.rand(A.shape[0].value, ).astype("float32"), ctx)
    b = tvm.nd.array(np.random.rand(B.shape[0].value, ).astype("float32"), ctx)
    c = tvm.nd.array(np.random.rand(C.shape[0].value, ).astype("float32"), ctx)

    func(a, b, c)
    result = c.asnumpy()

    answer = a.asnumpy() + 3
    tvm.testing.assert_allclose(result, answer, rtol=1e-5)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
    print(func.entry_name+': %f ms' % (evaluator(a, b, c).mean * 1e3))

if __name__ == "__main__":
    test()

Thank you for taking the time to look into this! Please let me know if you figure anything out.

Two issues here:

  1. Wrong CUDA code, producing wrong numerical results.
  2. More local memory allocated than necessary.

The first one is caused by a wrong relaxation decision during InferBound. In this example, threadIdx.x is bound in all three stages, and all of them are under the attaching points. InferBound, NeedRelax() in particular, decides NOT to relax when it sees ax0.inner in A.local in the schedule tree. Therefore it adds “if” statements to retain the wrongly inferred region.

This is probably a bug. I have a quick-fix prototype (https://github.com/yongfeng-nv/incubator-tvm/commit/b958a13188e055273da5f939938802332ebbbf4e). It generates the following IR and code, which produce correct results but allocate more local memory than necessary.

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [A.local] storage_scope = "local"
  allocate A.local[float32 * 128]
  produce A.local {
    for (ax0.outer, 0, 4) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      A.local[((ax0.outer*32) + threadIdx.x)] = A[((ax0.outer*32) + threadIdx.x)]
    }
  }
  for (ci.outer.inner, 0, 2) {
    produce B {
      for (bi.outer, 0, 2) {
        // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
        B[(((ci.outer.inner*64) + (bi.outer*32)) + threadIdx.x)] = (A.local[(((ci.outer.inner*64) + (bi.outer*32)) + threadIdx.x)] + 1f)
      }
    }
    for (ci.inner.outer, 0, 2) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      C[(((ci.outer.inner*64) + (ci.inner.outer*32)) + threadIdx.x)] = (B[(((ci.outer.inner*64) + (ci.inner.outer*32)) + threadIdx.x)] + 2f)
    }
  }
}

extern "C" __global__ void test_kernel0( float* __restrict__ A,  float* __restrict__ B,  float* __restrict__ C) {
   float A_local[128];
  for (int ax0_outer = 0; ax0_outer < 4; ++ax0_outer) {
    A_local[(((ax0_outer * 32) + ((int)threadIdx.x)))] = A[(((ax0_outer * 32) + ((int)threadIdx.x)))];
  }
  for (int ci_outer_inner = 0; ci_outer_inner < 2; ++ci_outer_inner) {
    for (int bi_outer = 0; bi_outer < 2; ++bi_outer) {
      B[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] = (A_local[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] + 1.000000e+00f);
    }
    for (int ci_inner_outer = 0; ci_inner_outer < 2; ++ci_inner_outer) {
      C[((((ci_outer_inner * 64) + (ci_inner_outer * 32)) + ((int)threadIdx.x)))] = (B[((((ci_outer_inner * 64) + (ci_inner_outer * 32)) + ((int)threadIdx.x)))] + 2.000000e+00f);
    }
  }
}

The second issue, the local memory size, is also related to InferBound. Here is my understanding; I would like folks with more knowledge to comment on it.

 allocate A.local[float32 * 128]

The size 128 comes from the fact that A.local is relaxed due to the prototype fix for the first issue. However, there are other underlying issues.

A.local[((ax0.outer*32) + threadIdx.x)] = A[((ax0.outer*32) + threadIdx.x)]

Ideally, when a thread accesses local memory, e.g. A.local, the index into the local memory should not contain threadIdx.x, even though threadIdx.x is needed to access shared/global memory, e.g. A. I am afraid that InferBound only provides bounds for the shared/global memory use cases, not for the local memory ones.
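As a side note, one quick way to see what InferBound decided is to walk the lowered IR and print every Allocate node. A minimal sketch against the pre-0.7 API used in this thread (report_allocs is a hypothetical helper and assumes constant extents):

def report_allocs(s, args):
    # Lower the schedule and print each Allocate with its inferred extents,
    # e.g. "A.local float32 [128]" for the case discussed above.
    stmt = tvm.lower(s, args, simple_mode=True)
    def visit(node):
        if isinstance(node, tvm.stmt.Allocate):
            print(node.buffer_var.name, node.dtype,
                  [int(e.value) for e in node.extents])
    tvm.ir_pass.PostOrderVisit(stmt, visit)

# Usage: report_allocs(schedule(A, B, C), [A, B, C])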

Changing “local” to “warp”

AA = s.cache_read(A, "warp", [B])

yields the following IR and code:

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [A.warp] storage_scope = "warp"
  allocate A.warp[float32 * 128]
  produce A.warp {
    for (ax0.outer, 0, 4) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      A.warp[((ax0.outer*32) + threadIdx.x)] = A[((ax0.outer*32) + threadIdx.x)]
    }
  }
  for (ci.outer.inner, 0, 2) {
    produce B {
      for (bi.outer, 0, 2) {
        // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
        B[(((ci.outer.inner*64) + (bi.outer*32)) + threadIdx.x)] = (A.warp[(((ci.outer.inner*64) + (bi.outer*32)) + threadIdx.x)] + 1f)
      }
    }
    for (ci.inner.outer, 0, 2) {
      // attr [iter_var(threadIdx.x, range(min=0, ext=32), threadIdx.x)] thread_extent = 32
      C[(((ci.outer.inner*64) + (ci.inner.outer*32)) + threadIdx.x)] = (B[(((ci.outer.inner*64) + (ci.inner.outer*32)) + threadIdx.x)] + 2f)
    }
  }
}

extern "C" __global__ void test_kernel0( float* __restrict__ A,  float* __restrict__ B,  float* __restrict__ C) {
   float A_warp[4];
  for (int ax0_outer = 0; ax0_outer < 4; ++ax0_outer) {
    A_warp[(ax0_outer)] = A[(((ax0_outer * 32) + ((int)threadIdx.x)))];
  }
  for (int ci_outer_inner = 0; ci_outer_inner < 2; ++ci_outer_inner) {
    for (int bi_outer = 0; bi_outer < 2; ++bi_outer) {
      B[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] = (A_warp[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] + 1.000000e+00f);
    }
    for (int ci_inner_outer = 0; ci_inner_outer < 2; ++ci_inner_outer) {
      C[((((ci_outer_inner * 64) + (ci_inner_outer * 32)) + ((int)threadIdx.x)))] = (B[((((ci_outer_inner * 64) + (ci_inner_outer * 32)) + ((int)threadIdx.x)))] + 2.000000e+00f);
    }
  }
}

The first stage and local memory size are correct:

   float A_warp[4];
  for (int ax0_outer = 0; ax0_outer < 4; ++ax0_outer) {
    A_warp[(ax0_outer)] = A[(((ax0_outer * 32) + ((int)threadIdx.x)))];
  }

But the second stage is still wrong:

      B[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] = (A_warp[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] + 1.000000e+00f);

It should index the per-thread buffer by element ordinal instead (each thread holds its 4 elements in ax0_outer order, so the slot is ci_outer_inner * 2 + bi_outer):

      B[((((ci_outer_inner * 64) + (bi_outer * 32)) + ((int)threadIdx.x)))] = (A_warp[ci_outer_inner * 2 + bi_outer] + 1.000000e+00f);

@tqchen, is “warp” a working feature? I found only a few unit tests using “warp”. I converted one to verify the numerical result, but it failed for the same reason.

I feel like the “warp” memory here refers to the registers used in warp shuffle or warp reduction. Evidence:

https://sourcegraph.com/github.com/apache/incubator-tvm@62424611c82a5be30913d17fcbcfef26506d328c/-/blob/src/runtime/thread_storage_scope.h#L48

I think what we are trying to do here is supposed to be done with “local” + virtual threads, although more fixes might be needed. @tqchen, can you give us more insights here?

Any advice here? @FrozenGene