[TVM IR] How to get the LHS IR representation?


#1

When we do GEMM, the inner loop looks like this:

C[ramp((((x.inner*8) + y.outer)*128), 1, 128)] = (C[ramp((((x.inner*8) + y.outer)*128), 1, 128)] + (x128(A[((x.inner*8) + k)])*B[ramp(((y.outer + (k*8))*128), 1, 128)]))

When we do VisitExpr, we can only get the Mul / Add nodes of the RHS, i.e. (C[ramp((((x.inner*8) + y.outer)*128), 1, 128)], +, (x128(A[((x.inner*8) + k)])*B[ramp(((y.outer + (k*8))*128), 1, 128)])).

Is there a way to get the C[ramp((((x.inner*8) + y.outer)*128), 1, 128)] on the left-hand side?

I'm asking because I want to check the IR pattern of GEMV and do something with it. From the doc https://docs.tvm.ai/tutorials/language/tensorize.html?highlight=tensorize I know we can recognize this pattern using tensorize, but can we also check the pattern in the CodeGen part?

@tqchen @yzhliu @vinx13, or anyone else, could you answer this question? Thanks.


#2

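I think you can get the LHS by visiting the Store or Provide node, and maybe even a Call node.

Store and Provide are statements rather than expressions, so the LHS they carry is visible from the statement visitor, not from VisitExpr alone.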
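To illustrate the difference, here is a minimal sketch in plain Python using a hypothetical toy IR. The classes Var, Load, Add, Mul, Store and Visitor are made up for this example and are not TVM's API. The LHS buffer and index are fields of the Store statement, so the visitor only sees them in its statement hook; the expression hook only ever receives the index and the RHS value. In TVM's C++ visitors this roughly corresponds to overriding the statement visit for Store (or Provide) in addition to the expression visits.

```python
from dataclasses import dataclass
from typing import List, Union

# Hypothetical toy IR (not TVM's real classes), loosely modelled on the
# Store / Load / Add / Mul nodes in the lowered GEMM loop above.
@dataclass
class Var:
    name: str

@dataclass
class Load:
    buffer: str
    index: "Expr"

@dataclass
class Add:
    a: "Expr"
    b: "Expr"

@dataclass
class Mul:
    a: "Expr"
    b: "Expr"

Expr = Union[Var, Load, Add, Mul]

@dataclass
class Store:
    buffer: str  # LHS buffer, e.g. "C"
    index: Expr  # LHS index (a ramp(...) in the real IR)
    value: Expr  # RHS expression

class Visitor:
    def __init__(self) -> None:
        self.stores: List[Store] = []  # statements seen (these carry the LHS)
        self.exprs: List[str] = []     # expression node types seen

    def visit_stmt(self, s: Store) -> None:
        # The LHS buffer and index are fields of the statement itself,
        # so this statement-level hook is where they can be read.
        self.stores.append(s)
        self.visit_expr(s.index)
        self.visit_expr(s.value)

    def visit_expr(self, e: Expr) -> None:
        self.exprs.append(type(e).__name__)
        if isinstance(e, Load):
            self.visit_expr(e.index)
        elif isinstance(e, (Add, Mul)):
            self.visit_expr(e.a)
            self.visit_expr(e.b)

# Simplified C[idx] = C[idx] + A[k] * B[idx] (ramp and broadcast dropped).
idx = Var("idx")
stmt = Store("C", idx, Add(Load("C", idx), Mul(Load("A", Var("k")), Load("B", idx))))
v = Visitor()
v.visit_stmt(stmt)
print(v.stores[0].buffer, v.stores[0].index)  # C Var(name='idx')
```

The expression hook does see the Load of C on the RHS, but nothing in the expression tree says which buffer and index are being written; only the Store does.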

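As I understand it, the whole statement is a Provide node (after storage flattening, which produces flat ramp indices like the ones below, the same statement appears as a Store):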

C[ramp((((x.inner*8) + y.outer)*128), 1, 128)] = (C[ramp((((x.inner*8) + y.outer)*128), 1, 128)] + (x128(A[((x.inner*8) + k)])*B[ramp(((y.outer + (k*8))*128), 1, 128)]))
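For the original goal of spotting this multiply-accumulate shape during CodeGen, a check on the statement could look like the sketch below. It again uses a hypothetical toy IR, not TVM's classes, and the helper match_multiply_accumulate is invented for illustration. It deliberately ignores details the real IR has, such as the x128 broadcast, ramp indices, and operands appearing in swapped order.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# Hypothetical toy IR nodes (not TVM's API); frozen dataclasses compare
# structurally, which is what the index check below relies on.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Load:
    buffer: str
    index: "Expr"

@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"

@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"

Expr = Union[Var, Load, Add, Mul]

@dataclass(frozen=True)
class Store:
    buffer: str
    index: Expr
    value: Expr

def match_multiply_accumulate(s: Store) -> Optional[Tuple[Expr, Expr]]:
    """Return (p, q) if s has the shape X[idx] = X[idx] + p * q, else None."""
    v = s.value
    if not (isinstance(v, Add) and isinstance(v.b, Mul)):
        return None
    acc = v.a
    # The accumulator read must hit the same buffer and index as the LHS.
    if isinstance(acc, Load) and acc.buffer == s.buffer and acc.index == s.index:
        return v.b.a, v.b.b
    return None

idx = Var("idx")
gemm = Store("C", idx, Add(Load("C", idx), Mul(Load("A", Var("k")), Load("B", idx))))
other = Store("C", idx, Add(Load("D", idx), Mul(Load("A", Var("k")), Load("B", idx))))
print(match_multiply_accumulate(gemm))   # the A and B loads
print(match_multiply_accumulate(other))  # None
```

The key point is that the match needs the Store's own buffer and index to compare against the accumulator read, which is exactly the LHS information that VisitExpr alone does not provide.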

For a simple example:

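  import tvm  # TVM <= 0.6 API; newer releases moved these to tvm.te
  shape = (16, 16)  # example shape
  a = tvm.placeholder(shape, dtype="float32", name="a")
  b = tvm.placeholder(shape, dtype="float32", name="b")
  c = tvm.compute(shape, lambda i, j: a[i, j] + b[i, j], name="c")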

Let’s say the HalideIR is as follows:

  c[i, j] = a[i, j] + b[i, j]

Here, the whole statement above is a Provide stmt for c.

A Provide stmt has the following attributes:

  1. FunctionRef func: the operation through which c was created. Here it is a ComputeOpNode.
  2. int value_index: if the operation returns multiple outputs, this is the index of c among them.
  3. Expr value: the Expr through which c is computed; here it is a[i, j] + b[i, j].
  4. Array<Expr> args: the indices into c for the computation; here they are [i, j].

The name of the Provide (here, c) can be obtained with c->func->name, which returns the name we provided when defining the ComputeOpNode for c.

A Call, in contrast, is an Expr.

In the above example, c[i, j] by itself is a Call node of type Halide.

A Call node has these attributes:

  1. Type type: the data type
  2. String name: here it is c
  3. CallType call_type: one of the call types, such as Halide, Intrinsic, PureIntrinsic, Extern and PureExtern.
    In the example, it is of type Halide.
  4. FunctionRef func: the same as in Provide; the ComputeOpNode for c[i, j], and the PlaceholderOpNode for a[i, j] and b[i, j]
  5. int value_index: same as in Provide
  6. Array<Expr> args: the indices of the call, same as in Provide

The Provide stmt c[i, j] = a[i, j] + b[i, j] can be decomposed as:

 {Call -> c[i, j]} = Add(Call -> a[i, j], Call -> b[i, j])
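The decomposition can be mimicked with a small toy model. This is a hypothetical sketch: Operation, Call, Add, Provide and the show/decompose helpers are invented for illustration and only mirror the attributes listed above; they are not TVM's classes. In this model the LHS Call is rebuilt from the Provide's own func and args, since the Provide stores those directly rather than holding a Call node.

```python
from dataclasses import dataclass
from typing import List, Union

# Hypothetical toy versions of the nodes described above (not TVM's API).
@dataclass
class Operation:
    name: str  # e.g. "c" for the ComputeOp, "a"/"b" for placeholders
    kind: str  # "ComputeOp" or "PlaceholderOp"

@dataclass
class Call:
    func: Operation
    args: List[str]
    value_index: int = 0
    call_type: str = "Halide"

    @property
    def name(self) -> str:
        # A Halide call is named after the operation it reads from.
        return self.func.name

@dataclass
class Add:
    a: "Expr"
    b: "Expr"

Expr = Union[Call, Add]

@dataclass
class Provide:
    func: Operation   # operation that produces the tensor
    value_index: int  # which output of func is written
    value: Expr       # RHS expression
    args: List[str]   # LHS indices

def show(e: Expr) -> str:
    """Render an expression in the decomposed form used above."""
    if isinstance(e, Call):
        return f"Call -> {e.name}[{', '.join(e.args)}]"
    return f"Add({show(e.a)}, {show(e.b)})"

def decompose(p: Provide) -> str:
    # The LHS is rebuilt from the Provide's own func and args.
    lhs = Call(p.func, p.args, p.value_index)
    return f"{{{show(lhs)}}} = {show(p.value)}"

a = Operation("a", "PlaceholderOp")
b = Operation("b", "PlaceholderOp")
c = Operation("c", "ComputeOp")
ij = ["i", "j"]
stmt = Provide(c, 0, Add(Call(a, ij), Call(b, ij)), ij)
print(stmt.func.name)   # c
print(decompose(stmt))  # {Call -> c[i, j]} = Add(Call -> a[i, j], Call -> b[i, j])
```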