[HalideIR] Enhance For

There are currently some limitations of For:

  1. start must be 0, which is enforced by codegen_c
  2. the step of the iterator is always 1

I’d suggest enhancing For to support this pattern:

__global__
void saxpy(int n, float a, float *x, float *y)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; 
         i < n; 
         i += blockDim.x * gridDim.x) 
      {
          y[i] = a * x[i] + y[i];
      }
}

@tqchen Do you think it is a good idea to support this kind of for expression in HalideIR? It’ll be helpful when we write kernels with the low-level API.


Is this really the case? codegen_c is only one of the many “backends” of TVM.
If I remember correctly, For loops are normalized in tvm.schedule.normalize(), and it is (AFAIK) a simplification so that InferBound is easier.
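
For reference, here is a minimal sketch of where that normalization runs (using the pre-0.5 tvm Python API current at the time of this thread; the toy compute and names are my own, just for illustration):

import tvm

# Toy one-dimensional compute, only to have a schedule to normalize.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1.0, "float32"), name="B")

s = tvm.create_schedule(B.op)
# normalize() rebases every IterVar range so that min = 0,
# which is what keeps InferBound simple downstream.
s = s.normalize()
stmt = tvm.lower(s, [A, B], simple_mode=True)
print(stmt)  # the printed For nodes all start at 0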

Yes, there is an assert there to make sure min starts at 0
https://github.com/dmlc/tvm/blob/b63267b92d942b9c64f814b73567b2fe908e67fb/src/codegen/codegen_c.cc#L826

What I mean is that codegen_c is not the place where start is forced to 0, but this check does make a non-zero start fail.

If ir_builder is used, there seems to be no normalization applied to these loops.
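
For example, a minimal sketch with the old tvm.ir_builder Python API (the buffer and the begin value are my own assumptions): for_range accepts a non-zero begin, and nothing rebases the resulting For node afterwards.

import tvm

ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.pointer("float32", name="A")
# begin = 4 is arbitrary, chosen only to show a non-zero min survives.
with ib.for_range(4, n, name="i") as i:
    A[i] = A[i] + tvm.const(1.0, "float32")
stmt = ib.get()
print(stmt)  # the For node keeps min = 4, so codegen_c's check would trip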

Most likely we can use the same normalized loop to represent the same program, and a low-level program optimizer will detect such a loop and rewrite it to the strided version:

for (int i = 0; i < extent; i++) {
    y[i * stride + min] = a * x[i * stride + min] + y[i * stride + min];
}
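
To see that the two forms cover the same indices, here is a quick sanity check in plain Python (the launch configuration is made up for illustration): with min set to the thread's global id and stride set to the total thread count, the normalized loop enumerates exactly the indices the grid-stride loop visits.

# Hypothetical sizes, chosen only for illustration.
n = 1000
block_dim, grid_dim = 128, 4
stride = block_dim * grid_dim

for tid in range(stride):  # tid plays blockIdx.x * blockDim.x + threadIdx.x
    start = tid                                   # "min" in the normalized form
    extent = (n - start + stride - 1) // stride   # ceil((n - start) / stride)
    normalized = [i * stride + start for i in range(extent)]
    grid_stride = list(range(tid, n, stride))
    assert normalized == grid_stride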

Would you mind giving your insights as to which “low level program optimizers” do this?

  1. Do LLVM’s middle-end passes generally do this?
  2. What if we want to design a backend for a HW target that LLVM cannot compile to?

I am thinking that since the problem is due to normalization (required by other TVM routines), TVM could “denormalize” afterwards.

Hi, I checked the TVM source code and never saw blockDim.x or gridDim.x there. I think loop fusion is fine when the loops’ start, condition, and step are the same.