[topi][winograd] test_topi_conv2d_winograd.py fails for a given shape

test_topi_conv2d_winograd.py fails if I add the following test:

verify_conv2d_nchw(1, 128, 64, 64, 3, 1, 1)

The error output is as follows:

Workload: (1, 128, 64, 64, 3, 1, 1, 1)
Running on target: cuda
[21:23:45] /home/ubuntu/tvm/src/pass/loop_partition.cc:545: Cannot prove: ((((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256)) - 1) - 1) + 1) >= 0), when generating the post doubt loop
Traceback (most recent call last):

  File "topi/tests/python/test_topi_conv2d_winograd.py", line 110, in <module>
    test_conv2d_nchw()

  File "topi/tests/python/test_topi_conv2d_winograd.py", line 107, in test_conv2d_nchw
    verify_conv2d_nchw(1, 128, 64, 64, 3, 1, 1)

  File "topi/tests/python/test_topi_conv2d_winograd.py", line 91, in verify_conv2d_nchw
    check_device(device)

  File "topi/tests/python/test_topi_conv2d_winograd.py", line 82, in check_device
    func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))

  File "/home/ubuntu/tvm/python/tvm/build_module.py", line 573, in build
    binds=binds)

  File "/home/ubuntu/tvm/python/tvm/build_module.py", line 416, in lower
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)

  File "/home/ubuntu/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /home/ubuntu/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7efd6dc8cc31]
  [bt] (2) /home/ubuntu/tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::LoweredFunc (tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>::AssignTypedLambda<tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)>(tvm::LoweredFunc (*)(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x120) [0x7efd6d5060b0]
  [bt] (1) /home/ubuntu/tvm/build/libtvm.so(tvm::ir::MakeAPI(tvm::Stmt, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::Array<tvm::NodeRef, void>, int, bool)+0x3cd5) [0x7efd6d7a0ae5]
  [bt] (0) /home/ubuntu/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7efd6d4c4382]
  File "/home/ubuntu/tvm/src/pass/make_api.cc", line 205
TVMError: Not all Vars are passed in api_args:  'threadIdx.x'  does not appear in api_args

I will dig into it further, but it would be very helpful if someone more familiar with the winograd schedule could take a look. Thanks.

CC @merrymercy @cbalint13 @kevinthesun

One way to dig would be to check whether it is caused by certain passes, e.g. disable the loop partitioner to see what is going on, or dump out the intermediate stmt.

@tqchen Thanks. That’s something I am looking into now. The loop partitioner seems okay: I disabled it and the error remains.

I reduced it to the minimum number of needed passes:

      binds, arg_list = get_binds(args, binds)
      cfg = current_build_config()
      if isinstance(sch, schedule.Schedule):
          stmt = form_body(sch)
      stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
      stmt = ir_pass.InjectVirtualThread(stmt)
      stmt = ir_pass.StorageRewrite(stmt)
      return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)

The stmt printed after StorageRewrite is like the following:

allocate kernel_pack[float32 * 294912]
// attr [data_pack] storage_scope = "global"
allocate data_pack[float32 * 1179648]
// attr [bgemm] storage_scope = "global"
allocate bgemm[float32 * 589824]
produce kernel_pack {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 64
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
  unrolled (eps, 0, 6) {
    unrolled (nu, 0, 6) {
      kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] = 0f
      unrolled (r_kh, 0, 3) {
        unrolled (r_kw, 0, 3) {
          kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] = (kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] + ((W[(((((((blockIdx.x*128) + threadIdx.x) % 64)*1152) + ((((blockIdx.x*128) + threadIdx.x)/64)*9)) + (r_kh*3)) + r_kw)]*select((((eps % 6) == 5) && ((r_kh % 3) == 2)), 1f, select((((eps % 6) == 5) && ((r_kh % 3) == 1)), 0f, select((((eps % 6) == 5) && ((r_kh % 3) == 0)), 0f, select((((eps % 6) == 4) && ((r_kh % 3) == 2)), 0.266667f, select((((eps % 6) == 4) && ((r_kh % 3) == 1)), -0.133333f, select((((eps % 6) == 4) && ((r_kh % 3) == 0)), 0.0666667f, select((((eps % 6) == 3) && ((r_kh % 3) == 2)), -0.266667f, select((((eps % 6) == 3) && ((r_kh % 3) == 1)), -0.533333f, select((((eps % 6) == 3) && ((r_kh % 3) == 0)), -1.06667f, select((((eps % 6) == 2) && ((r_kh % 3) == 2)), 0.333333f, select((((eps % 6) == 2) && ((r_kh % 3) == 1)), 0.333333f, select((((eps % 6) == 2) && ((r_kh % 3) == 0)), 0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 2)), -0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 1)), 0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 0)), -0.333333f, select((((eps % 6) == 0) && ((r_kh % 3) == 2)), 0f, select((((eps % 6) == 0) && ((r_kh % 3) == 1)), 0f, select((((eps % 6) == 0) && ((r_kh % 3) == 0)), 1f, 0f)))))))))))))))))))*select((((nu % 6) == 5) && ((r_kw % 3) == 2)), 1f, select((((nu % 6) == 5) && ((r_kw % 3) == 1)), 0f, select((((nu % 6) == 5) && ((r_kw % 3) == 0)), 0f, select((((nu % 6) == 4) && ((r_kw % 3) == 2)), 0.266667f, select((((nu % 6) == 4) && ((r_kw % 3) == 1)), -0.133333f, select((((nu % 6) == 4) && ((r_kw % 3) == 0)), 0.0666667f, select((((nu % 6) == 3) && ((r_kw % 3) == 2)), -0.266667f, select((((nu % 6) == 3) && ((r_kw % 3) == 1)), -0.533333f, select((((nu % 6) == 3) && ((r_kw % 3) == 0)), -1.06667f, select((((nu % 6) == 2) && ((r_kw % 3) == 2)), 0.333333f, select((((nu % 6) == 2) && ((r_kw % 3) == 1)), 0.333333f, select((((nu % 6) == 2) && ((r_kw % 3) == 
0)), 0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 2)), -0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 1)), 0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 0)), -0.333333f, select((((nu % 6) == 0) && ((r_kw % 3) == 2)), 0f, select((((nu % 6) == 0) && ((r_kw % 3) == 1)), 0f, select((((nu % 6) == 0) && ((r_kw % 3) == 0)), 1f, 0f))))))))))))))))))))
        }
      }
    }
  }
}
produce data_pack {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 256
  // attr [d] storage_scope = "local"
  allocate d[float32 * 36]
  // attr [data_pack.local] storage_scope = "local"
  allocate data_pack.local[float32 * 36]
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
  produce d {
    for (eps, 0, 6) {
      for (nu, 0, 6) {
        d[((eps*6) + nu)] = tvm_if_then_else((((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16) % 16)*4) + eps) >= 1) && ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16) % 16)*4) + eps) < 65)) && (((((((blockIdx.x*128) + threadIdx.x) % 256) % 16)*4) + nu) >= 1)) && (((((((blockIdx.x*128) + threadIdx.x) % 256) % 16)*4) + nu) < 65)), A[((((((((((blockIdx.x*128) + threadIdx.x) % 256)/256)*524288) + ((((blockIdx.x*128) + threadIdx.x)/16)*256)) + (eps*64)) + ((((blockIdx.x*128) + threadIdx.x) % 16)*4)) + nu) - 65)], 0f)
      }
    }
  }
  produce data_pack.local {
    unrolled (eps.c, 0, 6) {
      unrolled (nu.c, 0, 6) {
        data_pack.local[((eps.c*6) + nu.c)] = 0f
        unrolled (r_a, 0, 6) {
          unrolled (r_a, 0, 6) {
            data_pack.local[((eps.c*6) + nu.c)] = (data_pack.local[((eps.c*6) + nu.c)] + ((d[((r_a*6) + r_a)]*select((((r_a % 6) == 5) && ((eps.c % 6) == 5)), 1f, select((((r_a % 6) == 5) && ((eps.c % 6) == 4)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 3)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 2)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 1)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 0)), 0f, select((((r_a % 6) == 4) && ((eps.c % 6) == 5)), 1.5f, select((((r_a % 6) == 4) && ((eps.c % 6) == 4)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 3)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 2)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 1)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 0)), 1f, select((((r_a % 6) == 3) && ((eps.c % 6) == 5)), -2f, select((((r_a % 6) == 3) && ((eps.c % 6) == 4)), -0.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 3)), 2f, select((((r_a % 6) == 3) && ((eps.c % 6) == 2)), 2.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 1)), 0.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 0)), 1.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 5)), -1.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 4)), -1f, select((((r_a % 6) == 2) && ((eps.c % 6) == 3)), -1f, select((((r_a % 6) == 2) && ((eps.c % 6) == 2)), 0.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 1)), -2.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 0)), -2f, select((((r_a % 6) == 1) && ((eps.c % 6) == 5)), 1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 4)), 0.5f, select((((r_a % 6) == 1) && ((eps.c % 6) == 3)), -2f, select((((r_a % 6) == 1) && ((eps.c % 6) == 2)), -1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 1)), 1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 0)), -1.5f, select((((r_a % 6) == 0) && ((eps.c % 6) == 5)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 4)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 3)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 2)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 1)), 0f, 
select((((r_a % 6) == 0) && ((eps.c % 6) == 0)), 1f, 0f)))))))))))))))))))))))))))))))))))))*select((((r_a % 6) == 5) && ((nu.c % 6) == 5)), 1f, select((((r_a % 6) == 5) && ((nu.c % 6) == 4)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 3)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 2)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 1)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 0)), 0f, select((((r_a % 6) == 4) && ((nu.c % 6) == 5)), 1.5f, select((((r_a % 6) == 4) && ((nu.c % 6) == 4)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 3)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 2)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 1)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 0)), 1f, select((((r_a % 6) == 3) && ((nu.c % 6) == 5)), -2f, select((((r_a % 6) == 3) && ((nu.c % 6) == 4)), -0.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 3)), 2f, select((((r_a % 6) == 3) && ((nu.c % 6) == 2)), 2.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 1)), 0.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 0)), 1.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 5)), -1.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 4)), -1f, select((((r_a % 6) == 2) && ((nu.c % 6) == 3)), -1f, select((((r_a % 6) == 2) && ((nu.c % 6) == 2)), 0.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 1)), -2.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 0)), -2f, select((((r_a % 6) == 1) && ((nu.c % 6) == 5)), 1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 4)), 0.5f, select((((r_a % 6) == 1) && ((nu.c % 6) == 3)), -2f, select((((r_a % 6) == 1) && ((nu.c % 6) == 2)), -1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 1)), 1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 0)), -1.5f, select((((r_a % 6) == 0) && ((nu.c % 6) == 5)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 4)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 3)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 2)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 1)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 0)), 1f, 
0f))))))))))))))))))))))))))))))))))))))
          }
        }
      }
    }
  }
  for (eps, 0, 6) {
    for (nu, 0, 6) {
      data_pack[(((eps*196608) + (nu*32768)) + ((blockIdx.x*128) + threadIdx.x))] = data_pack.local[((eps*6) + nu)]
    }
  }
}
produce bgemm {
  // attr [iter_var(eps.nu.fused.outer, )] pragma_auto_unroll_max_step = 0
  // attr [iter_var(eps.nu.fused.outer, )] pragma_unroll_explicit = 0
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 36
  // attr [bgemm.local] storage_scope = "local"
  allocate bgemm.local[float32 * 1]
  // attr [kernel_pack.shared] storage_scope = "shared"
  allocate kernel_pack.shared[float32 * 1]
  // attr [data_pack.shared] storage_scope = "shared"
  allocate data_pack.shared[float32 * 1]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 64
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 256
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
  produce bgemm.local {
    bgemm.local[0] = 0f
    for (ci.outer, 0, 128) {
      produce kernel_pack.shared {
        // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
        // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
        // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
        kernel_pack.shared[0] = kernel_pack[(((blockIdx.z*8192) + (ci.outer*64)) + blockIdx.y)]
      }
      produce data_pack.shared {
        // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
        // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
        // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
        data_pack.shared[0] = data_pack[(((blockIdx.z*32768) + (ci.outer*256)) + blockIdx.x)]
      }
      bgemm.local[0] = (bgemm.local[0] + (kernel_pack.shared[0]*data_pack.shared[0]))
    }
  }
  bgemm[((((((0*16384) + (blockIdx.z*16384)) + (0*256)) + (blockIdx.y*256)) + 0) + blockIdx.x)] = bgemm.local[(((((((0 + blockIdx.z)/6) + 0) + 0) + ((0 + blockIdx.z) % 6)) - (blockIdx.z/6)) - (blockIdx.z % 6))]
}
produce output {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 128
  // attr [inverse] storage_scope = "local"
  allocate inverse[float32 * ((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))*4)*4)]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
  produce inverse {
    for (p, 0, ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))) {
      unrolled (vh, 0, 4) {
        unrolled (vw, 0, 4) {
          inverse[(((p*16) + (vh*4)) + vw)] = 0f
          unrolled (r_a, 0, 6) {
            unrolled (r_a, 0, 6) {
              if (likely(((p + (((((blockIdx.x*8) + (threadIdx.x/16)) % 16)*16) + (threadIdx.x % 16))) < 256))) {
                inverse[(((p*16) + (vh*4)) + vw)] = (inverse[(((p*16) + (vh*4)) + vw)] + ((bgemm[((((p + threadIdx.x) + (r_a*98304)) + (r_a*16384)) + ((blockIdx.x*8)*16))]*select((((r_a % 6) == 5) && ((vh % 4) == 3)), 1f, select((((r_a % 6) == 5) && ((vh % 4) == 2)), 0f, select((((r_a % 6) == 5) && ((vh % 4) == 1)), 0f, select((((r_a % 6) == 5) && ((vh % 4) == 0)), 0f, select((((r_a % 6) == 4) && ((vh % 4) == 3)), -8f, select((((r_a % 6) == 4) && ((vh % 4) == 2)), 4f, select((((r_a % 6) == 4) && ((vh % 4) == 1)), -2f, select((((r_a % 6) == 4) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 3) && ((vh % 4) == 3)), 0.125f, select((((r_a % 6) == 3) && ((vh % 4) == 2)), 0.25f, select((((r_a % 6) == 3) && ((vh % 4) == 1)), 0.5f, select((((r_a % 6) == 3) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 3)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 2)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 1)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 1) && ((vh % 4) == 3)), -1f, select((((r_a % 6) == 1) && ((vh % 4) == 2)), 1f, select((((r_a % 6) == 1) && ((vh % 4) == 1)), -1f, select((((r_a % 6) == 1) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 0) && ((vh % 4) == 3)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 2)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 1)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 0)), 1f, 0f)))))))))))))))))))))))))*select((((r_a % 6) == 5) && ((vw % 4) == 3)), 1f, select((((r_a % 6) == 5) && ((vw % 4) == 2)), 0f, select((((r_a % 6) == 5) && ((vw % 4) == 1)), 0f, select((((r_a % 6) == 5) && ((vw % 4) == 0)), 0f, select((((r_a % 6) == 4) && ((vw % 4) == 3)), -8f, select((((r_a % 6) == 4) && ((vw % 4) == 2)), 4f, select((((r_a % 6) == 4) && ((vw % 4) == 1)), -2f, select((((r_a % 6) == 4) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 3) && ((vw % 4) == 3)), 0.125f, select((((r_a % 6) == 3) && ((vw % 4) == 2)), 0.25f, select((((r_a % 6) == 3) && ((vw % 4) == 1)), 0.5f, select((((r_a % 6) == 3) 
&& ((vw % 4) == 0)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 3)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 2)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 1)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 1) && ((vw % 4) == 3)), -1f, select((((r_a % 6) == 1) && ((vw % 4) == 2)), 1f, select((((r_a % 6) == 1) && ((vw % 4) == 1)), -1f, select((((r_a % 6) == 1) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 0) && ((vw % 4) == 3)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 2)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 1)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 0)), 1f, 0f))))))))))))))))))))))))))
              }
            }
          }
        }
      }
    }
  }
  for (h.inner, 0, 4) {
    for (w.inner, 0, 4) {
      output[(((((((blockIdx.x*128) + threadIdx.x)/16)*256) + (h.inner*64)) + ((((blockIdx.x*128) + threadIdx.x) % 16)*4)) + w.inner)] = inverse[(((((((((((blockIdx.x*128) + threadIdx.x)/16384)*4096) + ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*4) + h.inner)/4)*256)) + (((((((blockIdx.x*128) + threadIdx.x) % 16)*4) + w.inner)/4)*16)) + (((((((blockIdx.x*128) + threadIdx.x) % 16384)/256) - (((blockIdx.x*8) + (threadIdx.x/16))/16))*((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256)))*16)) + ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*4) + h.inner) % 4)*4)) + ((((((blockIdx.x*128) + threadIdx.x) % 16)*4) + w.inner) % 4)) - ((threadIdx.x % 16)*16)) - ((((blockIdx.x*8) + (threadIdx.x/16)) % 16)*256))]
    }
  }
}

Please note the following in the output block:

allocate inverse[float32 * ((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))*4)*4)]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
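
For reference, the extent in that allocate can be checked by brute force outside TVM. The sketch below is plain Python, not TVM code. It assumes blockIdx.x and threadIdx.x each range over [0, 128), as the thread_extent attributes of the output block state, and the name inverse_extent is only illustrative.

```python
# Brute-force evaluation of the extent that appears in
# `allocate inverse[float32 * (<extent> * 4) * 4]`.
# Assumption: blockIdx.x and threadIdx.x both range over [0, 128).


def inverse_extent(block_idx_x: int, thread_idx_x: int) -> int:
    g = block_idx_x * 128 + thread_idx_x
    # All operands are non-negative, so Python's floor division gives the
    # same result as the truncating division used in the IR.
    return ((g % 256) // 16) * 16 + (thread_idx_x % 16) + 1 - (g % 256)


# Collect every value the extent takes over the whole launch grid.
extents = {inverse_extent(b, t) for b in range(128) for t in range(128)}
print(sorted(extents))  # [1]

# Number of float32 elements the allocation would need.
alloc_elems = max(extents) * 4 * 4
print(alloc_elems)  # 16
```

Every combination gives an extent of 1, so the buffer needs 16 floats, and the condition reported by loop_partition.cc (extent - 1 >= 0 after cancelling the trailing terms) reduces to 0 >= 0. The value is constant; the simplifier just could not show it.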

Before this PR (https://github.com/dmlc/tvm/pull/3368), it looked like this:

// attr [U] storage_scope = "global"
allocate U[float32 * 294912]
// attr [V] storage_scope = "global"
allocate V[float32 * 1179648]
produce U {
  parallel (k, 0, 64) {
    for (c, 0, 128) {
      unrolled (eps, 0, 6) {
        unrolled (nu, 0, 6) {
          U[((((((eps*6) + nu)*64) + k)*128) + c)] = 0.000000f
          unrolled (r_kh, 0, 3) {
            unrolled (r_kw, 0, 3) {
              U[((((((eps*6) + nu)*64) + k)*128) + c)] = (U[((((((eps*6) + nu)*64) + k)*128) + c)] + ((W[((((((k*128) + c)*3) + r_kh)*3) + r_kw)]*select((((eps % 6) == 5) && ((r_kh % 3) == 2)), 1.000000f, select((((eps % 6) == 5) && ((r_kh % 3) == 1)), 0.000000f, select((((eps % 6) == 5) && ((r_kh % 3) == 0)), 0.000000f, select((((eps % 6) == 4) && ((r_kh % 3) == 2)), 0.166667f, select((((eps % 6) == 4) && ((r_kh % 3) == 1)), -0.083333f, select((((eps % 6) == 4) && ((r_kh % 3) == 0)), 0.041667f, select((((eps % 6) == 3) && ((r_kh % 3) == 2)), 0.166667f, select((((eps % 6) == 3) && ((r_kh % 3) == 1)), 0.083333f, select((((eps % 6) == 3) && ((r_kh % 3) == 0)), 0.041667f, select((((eps % 6) == 2) && ((r_kh % 3) == 2)), -0.166667f, select((((eps % 6) == 2) && ((r_kh % 3) == 1)), 0.166667f, select((((eps % 6) == 2) && ((r_kh % 3) == 0)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 2)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 1)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 0)), -0.166667f, select((((eps % 6) == 0) && ((r_kh % 3) == 2)), 0.000000f, select((((eps % 6) == 0) && ((r_kh % 3) == 1)), 0.000000f, select((((eps % 6) == 0) && ((r_kh % 3) == 0)), 0.250000f, 0.000000f)))))))))))))))))))*select((((nu % 6) == 5) && ((r_kw % 3) == 2)), 1.000000f, select((((nu % 6) == 5) && ((r_kw % 3) == 1)), 0.000000f, select((((nu % 6) == 5) && ((r_kw % 3) == 0)), 0.000000f, select((((nu % 6) == 4) && ((r_kw % 3) == 2)), 0.166667f, select((((nu % 6) == 4) && ((r_kw % 3) == 1)), -0.083333f, select((((nu % 6) == 4) && ((r_kw % 3) == 0)), 0.041667f, select((((nu % 6) == 3) && ((r_kw % 3) == 2)), 0.166667f, select((((nu % 6) == 3) && ((r_kw % 3) == 1)), 0.083333f, select((((nu % 6) == 3) && ((r_kw % 3) == 0)), 0.041667f, select((((nu % 6) == 2) && ((r_kw % 3) == 2)), -0.166667f, select((((nu % 6) == 2) && ((r_kw % 3) == 1)), 0.166667f, select((((nu % 6) == 2) && ((r_kw % 3) == 0)), -0.166667f, select((((nu % 6) == 1) && ((r_kw % 3) == 2)), 
-0.166667f, select((((nu % 6) == 1) && ((r_kw % 3) == 1)), -0.166667f, select((((nu % 6) == 1) && ((r_kw % 3) == 0)), -0.166667f, select((((nu % 6) == 0) && ((r_kw % 3) == 2)), 0.000000f, select((((nu % 6) == 0) && ((r_kw % 3) == 1)), 0.000000f, select((((nu % 6) == 0) && ((r_kw % 3) == 0)), 0.250000f, 0.000000f))))))))))))))))))))
            }
          }
        }
      }
    }
  }
}
produce V {
  parallel (b, 0, 256) {
    // attr [d.global] storage_scope = "global"
    allocate d.global[float32 * 36]
    for (c, 0, 128) {
      produce d.global {
        for (ax2, 0, 6) {
          for (ax3, 0, 6) {
            d.global[((ax2*6) + ax3)] = tvm_if_then_else((((((((((b + 0)/16) % 16)*4) + ax2) >= 1) && ((((((b + 0)/16) % 16)*4) + ax2) < 65)) && (((((b + 0) % 16)*4) + ax3) >= 1)) && (((((b + 0) % 16)*4) + ax3) < 65)), A[((((((((b/256)*128) + c)*64) + ((((b/16) % 16)*4) + ax2))*64) + (((b % 16)*4) + ax3)) + -65)], 0.000000f)
          }
        }
      }
      unrolled (eps, 0, 6) {
        unrolled (nu, 0, 6) {
          V[((((((eps*6) + nu)*256) + b)*128) + c)] = 0.000000f
          unrolled (r_eps, 0, 6) {
            unrolled (r_nu, 0, 6) {
              V[((((((eps*6) + nu)*256) + b)*128) + c)] = (V[((((((eps*6) + nu)*256) + b)*128) + c)] + ((d.global[((r_eps*6) + r_nu)]*select((((r_eps % 6) == 5) && ((eps % 6) == 5)), 1.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 4)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 3)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 2)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 1)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 5)), 0.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 4)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 3)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 2)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 1)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 0)), 1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 5)), -5.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 4)), -2.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 3)), 2.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 2)), -1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 1)), 1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 5)), 0.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 4)), -1.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 3)), -1.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 2)), -4.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 1)), -4.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 0)), -5.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 5)), 4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 4)), 2.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 3)), -2.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 2)), 4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 1)), -4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 5)), 
0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 4)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 3)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 2)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 1)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 0)), 4.000000f, 0.000000f)))))))))))))))))))))))))))))))))))))*select((((r_nu % 6) == 5) && ((nu % 6) == 5)), 1.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 4)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 3)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 2)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 1)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 4)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 3)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 2)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 1)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 0)), 1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 5)), -5.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 4)), -2.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 3)), 2.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 2)), -1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 1)), 1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 4)), -1.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 3)), -1.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 2)), -4.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 1)), -4.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 0)), -5.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 5)), 4.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 4)), 2.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 3)), -2.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 2)), 4.000000f, 
select((((r_nu % 6) == 1) && ((nu % 6) == 1)), -4.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 4)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 3)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 2)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 1)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 0)), 4.000000f, 0.000000f))))))))))))))))))))))))))))))))))))))
            }
          }
        }
      }
    }
  }
}
produce output {
  parallel (k.outer, 0, 64) {
    // attr [M] storage_scope = "global"
    allocate M[float32 * 9216]
    // attr [M.global] storage_scope = "global"
    allocate M.global[float32 * 36]
    // attr [Y] storage_scope = "global"
    allocate Y[float32 * 16]
    produce M {
      for (eps, 0, 6) {
        for (nu, 0, 6) {
          for (b.outer, 0, 256) {
            M[((((eps*6) + nu)*256) + b.outer)] = 0.000000f
            for (c.outer, 0, 128) {
              M[((((eps*6) + nu)*256) + b.outer)] = (M[((((eps*6) + nu)*256) + b.outer)] + (U[((((((eps*6) + nu)*64) + k.outer)*128) + c.outer)]*V[((((((eps*6) + nu)*256) + b.outer)*128) + c.outer)]))
            }
          }
        }
      }
    }
    for (h.outer, 0, 16) {
      for (w.outer, 0, 16) {
        produce M.global {
          for (ax0, 0, 6) {
            for (ax1, 0, 6) {
              M.global[((ax0*6) + ax1)] = M[((((ax0*6) + ax1)*256) + ((h.outer*16) + w.outer))]
            }
          }
        }
        produce Y {
          unrolled (vh, 0, 4) {
            unrolled (vw, 0, 4) {
              Y[((vh*4) + vw)] = 0.000000f
              unrolled (r_eps, 0, 6) {
                unrolled (r_nu, 0, 6) {
                  Y[((vh*4) + vw)] = (Y[((vh*4) + vw)] + ((M.global[((r_eps*6) + r_nu)]*select((((r_eps % 6) == 5) && ((vh % 4) == 3)), 1.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 2)), 0.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 1)), 0.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 0)), 0.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 3)), -8.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 2)), 4.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 1)), -2.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 3)), 8.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 2)), 4.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 1)), 2.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 3)), -1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 2)), 1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 1)), -1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 3)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 2)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 1)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 3)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 2)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 1)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 0)), 1.000000f, 0.000000f)))))))))))))))))))))))))*select((((r_nu % 6) == 5) && ((vw % 4) == 3)), 1.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 2)), 0.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 1)), 0.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 0)), 0.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 3)), -8.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 2)), 4.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 1)), -2.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 0)), 
1.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 3)), 8.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 2)), 4.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 1)), 2.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 3)), -1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 2)), 1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 1)), -1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 3)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 2)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 1)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 3)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 2)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 1)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 0)), 1.000000f, 0.000000f))))))))))))))))))))))))))
                }
              }
            }
          }
        }
        for (h.inner, 0, 4) {
          for (w.inner, 0, 4) {
            output[((((k.outer*64) + (h.inner + (h.outer*4)))*64) + (w.inner + (w.outer*4)))] = Y[(((((((((h.inner/4) + h.outer)*16) + ((w.inner/4) + w.outer)) - ((h.outer*16) + w.outer))*4) + (h.inner % 4))*4) + (w.inner % 4))]
          }
        }
      }
    }
  }
}

Something probably went wrong in StorageRewrite after simplification.

It seems this happened because the simplifier failed to simplify the index calculation, and then the allocation got lifted out. We should do two things:

  • Add a clear error message during rewriting to assert that the memory size is constant (not dependent on the internal thread variables).
  • Add more robust simplification of the indices, by populating more bounds context so that the simplifier knows that the bounds are non-negative. Related: https://github.com/dmlc/tvm/pull/3981
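
As a rough illustration of the first item, the check amounts to asking whether an allocation extent still mentions a thread variable. The sketch below is plain Python over expression strings, not the TVM pass itself. THREAD_VARS, check_alloc_extent and the flattened INVERSE_EXTENT string are hypothetical names introduced only for this example.

```python
import ast

# Hypothetical set of thread variables an allocation size must not depend on.
THREAD_VARS = {
    "blockIdx.x", "blockIdx.y", "blockIdx.z",
    "threadIdx.x", "threadIdx.y", "threadIdx.z",
}


class _VarCollector(ast.NodeVisitor):
    """Collect variable names, keeping dotted names such as threadIdx.x whole."""

    def __init__(self):
        self.names = set()

    def visit_Attribute(self, node):
        if isinstance(node.value, ast.Name):
            self.names.add(f"{node.value.id}.{node.attr}")
        else:
            self.generic_visit(node)

    def visit_Name(self, node):
        self.names.add(node.id)


def check_alloc_extent(buffer_name: str, extent_expr: str) -> None:
    """Raise ValueError if the extent expression mentions a thread variable."""
    collector = _VarCollector()
    collector.visit(ast.parse(extent_expr, mode="eval"))
    offending = sorted(collector.names & THREAD_VARS)
    if offending:
        raise ValueError(
            f"allocation size of '{buffer_name}' depends on thread "
            f"variables {offending}; it must be constant"
        )


# Extent of `inverse` from the dump, with redundant parentheses removed.
INVERSE_EXTENT = (
    "(((blockIdx.x*128 + threadIdx.x) % 256)/16*16 + threadIdx.x % 16 + 1"
    " - (blockIdx.x*128 + threadIdx.x) % 256)*4*4"
)

check_alloc_extent("Y", "4*4")  # constant size, passes silently
try:
    check_alloc_extent("inverse", INVERSE_EXTENT)
except ValueError as err:
    print(err)
```

With a message like this, the failure would point at the non-constant allocation instead of surfacing later in MakeAPI as a missing 'threadIdx.x' in api_args.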

@zhiics it would be great if you could help contribute to these :)

@tqchen Thanks for the information. Sure. I can give it a try when I get cycles in the next couple of days.

@tqchen I am working on it. I realized that the same shape works if I change the tile size from 4 to 8 here:

The difference between the bounds in StorageFlatten is that one can be simplified to constants but the other cannot, here:

I tried to simplify r->extent, since blockIdx.x and threadIdx.x are bounded in the bounded_analyzer_, but it didn’t help. Should we also bind the bounds of a Realize node in the IRVisitorWithAnalyzer?

If so, I am not sure how we can bind each Range in the bounds to a var, since Realize doesn’t have an itervar. Could you please advise whether this is the right approach, or otherwise what should be bound to help the simplifier? Thanks.
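
One possible reason why bounding blockIdx.x and threadIdx.x alone does not help is that a pure range analysis treats the two occurrences of ((blockIdx.x*128) + threadIdx.x) % 256 as unrelated values. The toy interval arithmetic below is a plain Python model, not TVM's analyzer. The Interval class and its methods are made up for this sketch, and the [0, 128) ranges are taken from the thread_extent attributes of the output block.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Closed integer interval [lo, hi]; operands are non-negative here."""
    lo: int
    hi: int

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def __sub__(self, other):
        # Worst case: the two operands are assumed to vary independently.
        return Interval(self.lo - other.hi, self.hi - other.lo)

    def times(self, k):
        # k is a positive constant.
        return Interval(self.lo * k, self.hi * k)

    def floordiv(self, k):
        # k is a positive constant.
        return Interval(self.lo // k, self.hi // k)

    def mod(self, m):
        # Non-negative operand: unchanged if it already fits below m.
        return self if self.hi < m else Interval(0, m - 1)


block_idx = Interval(0, 127)   # blockIdx.x, thread_extent = 128
thread_idx = Interval(0, 127)  # threadIdx.x, thread_extent = 128

g = block_idx.times(128) + thread_idx
g_mod = g.mod(256)

# Extent of `inverse`: (g % 256)/16*16 + threadIdx.x % 16 + 1 - g % 256
extent = g_mod.floordiv(16).times(16) + thread_idx.mod(16) + Interval(1, 1) - g_mod
print(extent)  # Interval(lo=-254, hi=256)
```

In this model the range comes out as [-254, 256], so neither constancy nor non-negativity follows from the variable bounds. Proving that the extent equals 1 also needs the algebraic facts that (x/16)*16 + x % 16 == x and that ((blockIdx.x*128) + threadIdx.x) % 16 == threadIdx.x % 16.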