Expr Simplifier for tvm.var


#1

Hi,

I got a warning while doing loop partitioning with a symbolic bound:

Cannot prove: ((((((n + 1)/2) - 1) - (((n - 4)/2) + 1)) + 1) >= 0), when generating the post doubt loop

This expression should be true. However, neither rewrite_simplify nor canonical_simplify can prove it. Did I miss something here?
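One quick sanity check is to evaluate the condition from the warning for many concrete values of n, under both floor division and C-style truncated division. The helpers floordiv and truncdiv below are defined only for this check and are not TVM APIs. Brute force over a finite range is not a proof. Working the even and odd cases by hand, however, shows the expression should always evaluate to 1 or 2 under either semantics.

```python
def floordiv(a, b):
    # Rounds toward negative infinity, like Python's //.
    return a // b


def truncdiv(a, b):
    # Rounds toward zero, like C integer division.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def warning_expr(n, div):
    # ((((((n + 1)/2) - 1) - (((n - 4)/2) + 1)) + 1), as printed in the warning
    return ((div(n + 1, 2) - 1) - (div(n - 4, 2) + 1)) + 1


for div in (floordiv, truncdiv):
    values = {warning_expr(n, div) for n in range(-1000, 1001)}
    print(div.__name__, sorted(values))
# floordiv [1, 2]
# truncdiv [1, 2]
```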

@tqchen


#2

I have also noticed that the simplifier is sometimes unable to simplify expressions involving the iteration variable,
e.g. i == -1, which is always false because an iteration variable is never negative.

@kevinthesun I am using SuperSimplify, as here, with vrange passed to it. This won't fix the root cause of the problem, but you can give it a try.
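Passing vrange matters because range information is what lets a simplifier fold a comparison like i == -1. The toy sketch below illustrates this interval reasoning. Once i is known to lie in [0, extent), the constant -1 falls outside the interval, so the equality is provably false. The Interval class and prove_eq_false helper are hypothetical illustrations, not TVM's arith API.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class Interval:
    """Closed integer interval [lo, hi]; None means unbounded on that side."""
    lo: Optional[int] = None
    hi: Optional[int] = None

    def contains(self, c: int) -> bool:
        return (self.lo is None or self.lo <= c) and (self.hi is None or c <= self.hi)


def prove_eq_false(var: str, const: int, bounds: Dict[str, Interval]) -> bool:
    """Return True if `var == const` can be proven false from the known bounds."""
    # With no recorded range the variable is unbounded, so nothing can be proven.
    iv = bounds.get(var, Interval())
    return not iv.contains(const)


# An iteration variable from `for (i, 0, 8)` lies in [0, 7].
print(prove_eq_false("i", -1, {}))                     # False: no bound information
print(prove_eq_false("i", -1, {"i": Interval(0, 7)}))  # True: -1 is outside [0, 7]
```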


#3

Thank you for this information! I can't find this pass in master. Has it been moved somewhere else?


#4

I created a minimal example for the symbolic expression issue:

```python
import tvm

# The first dimension is symbolic; the other two are static.
dshape = (tvm.var("n"), 72, 96)
target = "cuda"

def compute(data):
    # Element-wise multiply by 10; the output has the same (symbolic) shape.
    oshape = data.shape
    out = tvm.compute(oshape, lambda i, j, k: data[i, j, k] * 10)
    return out

def schedule(s, out):
    n, m, _ = s[out].op.axis
    # Split the symbolic axis n into ceil(n/32) x 4 x 8.
    bn_z, n = s[out].split(n, 32)
    bn_y, bn_x = s[out].split(n, 8)

    # Split the static axis m (72) into 6 x 12 x 1.
    tm_z, m = s[out].split(m, 12)
    tm_y, tm_x = s[out].split(m, 1)

    # Bind the pieces of n to blocks and the pieces of m to threads.
    s[out].bind(bn_z, tvm.thread_axis("blockIdx.z"))
    s[out].bind(bn_y, tvm.thread_axis("blockIdx.y"))
    s[out].bind(bn_x, tvm.thread_axis("blockIdx.x"))

    s[out].bind(tm_z, tvm.thread_axis("threadIdx.z"))
    s[out].bind(tm_y, tvm.thread_axis("threadIdx.y"))
    s[out].bind(tm_x, tvm.thread_axis("threadIdx.x"))
    return s


d = tvm.placeholder(dshape, name="data")
out = compute(d)
s = tvm.create_schedule(out.op)
s = schedule(s, out)
f = tvm.build(s, [d, out], target)
```

Lowered stmt printed in tvm.build:

```
produce compute {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = ((n + 31)/32)
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 8
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 6
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 12
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
  if (((blockIdx.z < (((n - 32)/32) + 1)) && (blockIdx.z < (((n + 31)/32) - 1)))) {
    for (k, 0, 96) {
      compute[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)] = (data[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)]*10f)
    }
  } else {
    for (k, 0, 96) {
      if (((((blockIdx.z*32) + (blockIdx.y*8)) + blockIdx.x) < n)) {
        if (((((blockIdx.z*32) + (blockIdx.y*8)) + blockIdx.x) < n)) {
          compute[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)] = (data[((((((blockIdx.z*221184) + (blockIdx.y*55296)) + (blockIdx.x*6912)) + (threadIdx.z*1152)) + (threadIdx.y*96)) + k)]*10f)
        }
      }
    }
  }
}
```

There are two problems: 1) the condition if (((blockIdx.z < (((n - 32)/32) + 1)) && (blockIdx.z < (((n + 31)/32) - 1)))) is not simplified; 2) several if statements sit inside the for loop over k even though their conditions do not depend on k, so they could be hoisted out of the loop to reduce the number of times they are evaluated.
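For the first problem, consider floor division. There the two bounds rewrite as (n - 32)/32 + 1 = n/32 and (n + 31)/32 - 1 = (n - 1)/32. Since (n - 1)/32 ≤ n/32, the first comparison is implied by the second, and the condition should reduce to blockIdx.z < (((n + 31)/32) - 1). The brute-force check below is a sketch that assumes n ≥ 1 and blockIdx.z ≥ 0, which holds for a shape variable and a block index. It compares the original conjunction with the single comparison under floor division and under C-style truncated division. The helper names are illustrative, not TVM APIs.

```python
def floordiv(a, b):
    # Rounds toward negative infinity.
    return a // b


def truncdiv(a, b):
    # Rounds toward zero (C semantics).
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def original_cond(bz, n, div):
    # (blockIdx.z < (((n - 32)/32) + 1)) && (blockIdx.z < (((n + 31)/32) - 1))
    return bz < div(n - 32, 32) + 1 and bz < div(n + 31, 32) - 1


def reduced_cond(bz, n, div):
    # blockIdx.z < (((n + 31)/32) - 1)
    return bz < div(n + 31, 32) - 1


for div in (floordiv, truncdiv):
    mismatches = [
        (n, bz)
        for n in range(1, 401)
        for bz in range((n + 31) // 32 + 1)
        if original_cond(bz, n, div) != reduced_cond(bz, n, div)
    ]
    print(div.__name__, "mismatches:", len(mismatches))
```

Under truncated division the first bound is not equal to n/32 when 1 ≤ n ≤ 31, yet the conjunction still collapses in the tested range. A simplifier working with truncdiv therefore needs extra case reasoning to see this, while under floordiv the identities hold directly.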

@tqchen Do you think this can be improved by the simplifier, or by other parts of TVM?


#5

This is likely because the simplifier was not given the bound information. Note that some of the cases depend on the division semantics; upgrading some of them to floordiv, which we have not done yet, might help in some cases.
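Here is a concrete instance of the division-semantics point. The rewrite (n - 32)/32 + 1 → n/32 is one step that would clean up the condition in the lowered code above. It is valid for every integer n under floor division. Under C-style truncated division, it fails for 1 ≤ n ≤ 31, because (n - 32)/32 rounds toward zero instead of down. The sketch below lists the values of n where the rewrite fails under each semantics. The helpers are plain Python written for this illustration, not TVM APIs.

```python
def floordiv(a, b):
    # Rounds toward negative infinity.
    return a // b


def truncdiv(a, b):
    # Rounds toward zero (C semantics).
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def rewrite_holds(n, div):
    # Is (n - 32)/32 + 1 == n/32 for this n?
    return div(n - 32, 32) + 1 == div(n, 32)


floor_failures = [n for n in range(-200, 201) if not rewrite_holds(n, floordiv)]
trunc_failures = [n for n in range(-200, 201) if not rewrite_holds(n, truncdiv)]
print(floor_failures)  # []
print(trunc_failures)  # the integers 1 through 31
```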


#6

Right, it is not on master because it has not been merged. It is simply three calls to the simplifier; you can find the implementation at the permalink I posted above.

CanonicalSimplify(Simplify(CanonicalSimplify(stmt, vrange), vrange), vrange). I am not sure whether this is any better with the new arith infrastructure.

So, as Tianqi said, you can first try passing the bound information to the simplifier, and if that does not help, try SuperSimplify.


#7

Do you think the second issue, regarding the position of the if statements, is also related to the expression simplifier? In more complicated symbolic workloads such as conv2d, this is the major performance bottleneck.


#8

@tqchen Do you think we need to add a loop-invariant optimization pass to deal with this issue?
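The loop-invariant transformation discussed here can be sketched on a toy IR. Suppose a for loop's whole body is an if statement whose condition does not read the loop variable. Then the if can be moved outside the loop: `for k: if c: S` becomes `if c: for k: S`. The Cond/Store/IfThenElse/For classes and the hoist_invariant_if function are hypothetical, written only to show the idea. They are not TVM's IR or an existing TVM pass. A real pass would also have to handle else branches, conditions with side effects, and ifs nested deeper inside the body.

```python
from dataclasses import dataclass
from typing import FrozenSet, Union


@dataclass(frozen=True)
class Cond:
    text: str                  # printable form of the condition
    free_vars: FrozenSet[str]  # variables the condition reads


@dataclass(frozen=True)
class Store:
    text: str


@dataclass(frozen=True)
class IfThenElse:  # toy version: no else branch
    cond: Cond
    then_case: "Stmt"


@dataclass(frozen=True)
class For:
    loop_var: str
    extent: int
    body: "Stmt"


Stmt = Union[Store, IfThenElse, For]


def hoist_invariant_if(stmt: Stmt) -> Stmt:
    """Move ifs whose condition does not use the loop variable out of the loop."""
    if isinstance(stmt, For):
        body = hoist_invariant_if(stmt.body)
        if isinstance(body, IfThenElse) and stmt.loop_var not in body.cond.free_vars:
            # for k: if c: S  ==>  if c: for k: S  (c is the same on every iteration)
            inner = hoist_invariant_if(For(stmt.loop_var, stmt.extent, body.then_case))
            return IfThenElse(body.cond, inner)
        return For(stmt.loop_var, stmt.extent, body)
    if isinstance(stmt, IfThenElse):
        return IfThenElse(stmt.cond, hoist_invariant_if(stmt.then_case))
    return stmt


# Mirrors the else branch of the lowered code: two k-independent ifs inside `for k`.
c = Cond("((blockIdx.z*32) + (blockIdx.y*8)) + blockIdx.x < n",
         frozenset({"blockIdx.z", "blockIdx.y", "blockIdx.x", "n"}))
store = Store("compute[...] = data[...]*10f")
before = For("k", 96, IfThenElse(c, IfThenElse(c, store)))
after = hoist_invariant_if(before)
print(after == IfThenElse(c, IfThenElse(c, For("k", 96, store))))  # True
```

Both checks end up outside the loop and run once per thread instead of 96 times. The duplicated condition itself is left for a separate simplification.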