I reduced it to the minimum number of needed passes:
binds, arg_list = get_binds(args, binds)
cfg = current_build_config()
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
The stmt printed after StorageRewrite looks like this:
allocate kernel_pack[float32 * 294912]
// attr [data_pack] storage_scope = "global"
allocate data_pack[float32 * 1179648]
// attr [bgemm] storage_scope = "global"
allocate bgemm[float32 * 589824]
produce kernel_pack {
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 64
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
unrolled (eps, 0, 6) {
unrolled (nu, 0, 6) {
kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] = 0f
unrolled (r_kh, 0, 3) {
unrolled (r_kw, 0, 3) {
kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] = (kernel_pack[(((eps*49152) + (nu*8192)) + ((blockIdx.x*128) + threadIdx.x))] + ((W[(((((((blockIdx.x*128) + threadIdx.x) % 64)*1152) + ((((blockIdx.x*128) + threadIdx.x)/64)*9)) + (r_kh*3)) + r_kw)]*select((((eps % 6) == 5) && ((r_kh % 3) == 2)), 1f, select((((eps % 6) == 5) && ((r_kh % 3) == 1)), 0f, select((((eps % 6) == 5) && ((r_kh % 3) == 0)), 0f, select((((eps % 6) == 4) && ((r_kh % 3) == 2)), 0.266667f, select((((eps % 6) == 4) && ((r_kh % 3) == 1)), -0.133333f, select((((eps % 6) == 4) && ((r_kh % 3) == 0)), 0.0666667f, select((((eps % 6) == 3) && ((r_kh % 3) == 2)), -0.266667f, select((((eps % 6) == 3) && ((r_kh % 3) == 1)), -0.533333f, select((((eps % 6) == 3) && ((r_kh % 3) == 0)), -1.06667f, select((((eps % 6) == 2) && ((r_kh % 3) == 2)), 0.333333f, select((((eps % 6) == 2) && ((r_kh % 3) == 1)), 0.333333f, select((((eps % 6) == 2) && ((r_kh % 3) == 0)), 0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 2)), -0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 1)), 0.333333f, select((((eps % 6) == 1) && ((r_kh % 3) == 0)), -0.333333f, select((((eps % 6) == 0) && ((r_kh % 3) == 2)), 0f, select((((eps % 6) == 0) && ((r_kh % 3) == 1)), 0f, select((((eps % 6) == 0) && ((r_kh % 3) == 0)), 1f, 0f)))))))))))))))))))*select((((nu % 6) == 5) && ((r_kw % 3) == 2)), 1f, select((((nu % 6) == 5) && ((r_kw % 3) == 1)), 0f, select((((nu % 6) == 5) && ((r_kw % 3) == 0)), 0f, select((((nu % 6) == 4) && ((r_kw % 3) == 2)), 0.266667f, select((((nu % 6) == 4) && ((r_kw % 3) == 1)), -0.133333f, select((((nu % 6) == 4) && ((r_kw % 3) == 0)), 0.0666667f, select((((nu % 6) == 3) && ((r_kw % 3) == 2)), -0.266667f, select((((nu % 6) == 3) && ((r_kw % 3) == 1)), -0.533333f, select((((nu % 6) == 3) && ((r_kw % 3) == 0)), -1.06667f, select((((nu % 6) == 2) && ((r_kw % 3) == 2)), 0.333333f, select((((nu % 6) == 2) && ((r_kw % 3) == 1)), 0.333333f, select((((nu % 6) == 2) && ((r_kw % 3) == 0)), 
0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 2)), -0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 1)), 0.333333f, select((((nu % 6) == 1) && ((r_kw % 3) == 0)), -0.333333f, select((((nu % 6) == 0) && ((r_kw % 3) == 2)), 0f, select((((nu % 6) == 0) && ((r_kw % 3) == 1)), 0f, select((((nu % 6) == 0) && ((r_kw % 3) == 0)), 1f, 0f))))))))))))))))))))
}
}
}
}
}
produce data_pack {
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 256
// attr [d] storage_scope = "local"
allocate d[float32 * 36]
// attr [data_pack.local] storage_scope = "local"
allocate data_pack.local[float32 * 36]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
produce d {
for (eps, 0, 6) {
for (nu, 0, 6) {
d[((eps*6) + nu)] = tvm_if_then_else((((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16) % 16)*4) + eps) >= 1) && ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16) % 16)*4) + eps) < 65)) && (((((((blockIdx.x*128) + threadIdx.x) % 256) % 16)*4) + nu) >= 1)) && (((((((blockIdx.x*128) + threadIdx.x) % 256) % 16)*4) + nu) < 65)), A[((((((((((blockIdx.x*128) + threadIdx.x) % 256)/256)*524288) + ((((blockIdx.x*128) + threadIdx.x)/16)*256)) + (eps*64)) + ((((blockIdx.x*128) + threadIdx.x) % 16)*4)) + nu) - 65)], 0f)
}
}
}
produce data_pack.local {
unrolled (eps.c, 0, 6) {
unrolled (nu.c, 0, 6) {
data_pack.local[((eps.c*6) + nu.c)] = 0f
unrolled (r_a, 0, 6) {
unrolled (r_a, 0, 6) {
data_pack.local[((eps.c*6) + nu.c)] = (data_pack.local[((eps.c*6) + nu.c)] + ((d[((r_a*6) + r_a)]*select((((r_a % 6) == 5) && ((eps.c % 6) == 5)), 1f, select((((r_a % 6) == 5) && ((eps.c % 6) == 4)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 3)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 2)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 1)), 0f, select((((r_a % 6) == 5) && ((eps.c % 6) == 0)), 0f, select((((r_a % 6) == 4) && ((eps.c % 6) == 5)), 1.5f, select((((r_a % 6) == 4) && ((eps.c % 6) == 4)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 3)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 2)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 1)), 1f, select((((r_a % 6) == 4) && ((eps.c % 6) == 0)), 1f, select((((r_a % 6) == 3) && ((eps.c % 6) == 5)), -2f, select((((r_a % 6) == 3) && ((eps.c % 6) == 4)), -0.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 3)), 2f, select((((r_a % 6) == 3) && ((eps.c % 6) == 2)), 2.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 1)), 0.5f, select((((r_a % 6) == 3) && ((eps.c % 6) == 0)), 1.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 5)), -1.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 4)), -1f, select((((r_a % 6) == 2) && ((eps.c % 6) == 3)), -1f, select((((r_a % 6) == 2) && ((eps.c % 6) == 2)), 0.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 1)), -2.5f, select((((r_a % 6) == 2) && ((eps.c % 6) == 0)), -2f, select((((r_a % 6) == 1) && ((eps.c % 6) == 5)), 1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 4)), 0.5f, select((((r_a % 6) == 1) && ((eps.c % 6) == 3)), -2f, select((((r_a % 6) == 1) && ((eps.c % 6) == 2)), -1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 1)), 1f, select((((r_a % 6) == 1) && ((eps.c % 6) == 0)), -1.5f, select((((r_a % 6) == 0) && ((eps.c % 6) == 5)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 4)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 3)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 2)), 0f, select((((r_a % 6) == 0) && ((eps.c % 6) == 1)), 0f, select((((r_a % 
6) == 0) && ((eps.c % 6) == 0)), 1f, 0f)))))))))))))))))))))))))))))))))))))*select((((r_a % 6) == 5) && ((nu.c % 6) == 5)), 1f, select((((r_a % 6) == 5) && ((nu.c % 6) == 4)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 3)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 2)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 1)), 0f, select((((r_a % 6) == 5) && ((nu.c % 6) == 0)), 0f, select((((r_a % 6) == 4) && ((nu.c % 6) == 5)), 1.5f, select((((r_a % 6) == 4) && ((nu.c % 6) == 4)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 3)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 2)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 1)), 1f, select((((r_a % 6) == 4) && ((nu.c % 6) == 0)), 1f, select((((r_a % 6) == 3) && ((nu.c % 6) == 5)), -2f, select((((r_a % 6) == 3) && ((nu.c % 6) == 4)), -0.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 3)), 2f, select((((r_a % 6) == 3) && ((nu.c % 6) == 2)), 2.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 1)), 0.5f, select((((r_a % 6) == 3) && ((nu.c % 6) == 0)), 1.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 5)), -1.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 4)), -1f, select((((r_a % 6) == 2) && ((nu.c % 6) == 3)), -1f, select((((r_a % 6) == 2) && ((nu.c % 6) == 2)), 0.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 1)), -2.5f, select((((r_a % 6) == 2) && ((nu.c % 6) == 0)), -2f, select((((r_a % 6) == 1) && ((nu.c % 6) == 5)), 1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 4)), 0.5f, select((((r_a % 6) == 1) && ((nu.c % 6) == 3)), -2f, select((((r_a % 6) == 1) && ((nu.c % 6) == 2)), -1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 1)), 1f, select((((r_a % 6) == 1) && ((nu.c % 6) == 0)), -1.5f, select((((r_a % 6) == 0) && ((nu.c % 6) == 5)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 4)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 3)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 2)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 1)), 0f, select((((r_a % 6) == 0) && ((nu.c % 6) == 0)), 1f, 
0f))))))))))))))))))))))))))))))))))))))
}
}
}
}
}
for (eps, 0, 6) {
for (nu, 0, 6) {
data_pack[(((eps*196608) + (nu*32768)) + ((blockIdx.x*128) + threadIdx.x))] = data_pack.local[((eps*6) + nu)]
}
}
}
produce bgemm {
// attr [iter_var(eps.nu.fused.outer, )] pragma_auto_unroll_max_step = 0
// attr [iter_var(eps.nu.fused.outer, )] pragma_unroll_explicit = 0
// attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 36
// attr [bgemm.local] storage_scope = "local"
allocate bgemm.local[float32 * 1]
// attr [kernel_pack.shared] storage_scope = "shared"
allocate kernel_pack.shared[float32 * 1]
// attr [data_pack.shared] storage_scope = "shared"
allocate data_pack.shared[float32 * 1]
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 64
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 256
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
produce bgemm.local {
bgemm.local[0] = 0f
for (ci.outer, 0, 128) {
produce kernel_pack.shared {
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
kernel_pack.shared[0] = kernel_pack[(((blockIdx.z*8192) + (ci.outer*64)) + blockIdx.y)]
}
produce data_pack.shared {
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
data_pack.shared[0] = data_pack[(((blockIdx.z*32768) + (ci.outer*256)) + blockIdx.x)]
}
bgemm.local[0] = (bgemm.local[0] + (kernel_pack.shared[0]*data_pack.shared[0]))
}
}
bgemm[((((((0*16384) + (blockIdx.z*16384)) + (0*256)) + (blockIdx.y*256)) + 0) + blockIdx.x)] = bgemm.local[(((((((0 + blockIdx.z)/6) + 0) + 0) + ((0 + blockIdx.z) % 6)) - (blockIdx.z/6)) - (blockIdx.z % 6))]
}
produce output {
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 128
// attr [inverse] storage_scope = "local"
allocate inverse[float32 * ((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))*4)*4)]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
produce inverse {
for (p, 0, ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))) {
unrolled (vh, 0, 4) {
unrolled (vw, 0, 4) {
inverse[(((p*16) + (vh*4)) + vw)] = 0f
unrolled (r_a, 0, 6) {
unrolled (r_a, 0, 6) {
if (likely(((p + (((((blockIdx.x*8) + (threadIdx.x/16)) % 16)*16) + (threadIdx.x % 16))) < 256))) {
inverse[(((p*16) + (vh*4)) + vw)] = (inverse[(((p*16) + (vh*4)) + vw)] + ((bgemm[((((p + threadIdx.x) + (r_a*98304)) + (r_a*16384)) + ((blockIdx.x*8)*16))]*select((((r_a % 6) == 5) && ((vh % 4) == 3)), 1f, select((((r_a % 6) == 5) && ((vh % 4) == 2)), 0f, select((((r_a % 6) == 5) && ((vh % 4) == 1)), 0f, select((((r_a % 6) == 5) && ((vh % 4) == 0)), 0f, select((((r_a % 6) == 4) && ((vh % 4) == 3)), -8f, select((((r_a % 6) == 4) && ((vh % 4) == 2)), 4f, select((((r_a % 6) == 4) && ((vh % 4) == 1)), -2f, select((((r_a % 6) == 4) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 3) && ((vh % 4) == 3)), 0.125f, select((((r_a % 6) == 3) && ((vh % 4) == 2)), 0.25f, select((((r_a % 6) == 3) && ((vh % 4) == 1)), 0.5f, select((((r_a % 6) == 3) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 3)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 2)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 1)), 1f, select((((r_a % 6) == 2) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 1) && ((vh % 4) == 3)), -1f, select((((r_a % 6) == 1) && ((vh % 4) == 2)), 1f, select((((r_a % 6) == 1) && ((vh % 4) == 1)), -1f, select((((r_a % 6) == 1) && ((vh % 4) == 0)), 1f, select((((r_a % 6) == 0) && ((vh % 4) == 3)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 2)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 1)), 0f, select((((r_a % 6) == 0) && ((vh % 4) == 0)), 1f, 0f)))))))))))))))))))))))))*select((((r_a % 6) == 5) && ((vw % 4) == 3)), 1f, select((((r_a % 6) == 5) && ((vw % 4) == 2)), 0f, select((((r_a % 6) == 5) && ((vw % 4) == 1)), 0f, select((((r_a % 6) == 5) && ((vw % 4) == 0)), 0f, select((((r_a % 6) == 4) && ((vw % 4) == 3)), -8f, select((((r_a % 6) == 4) && ((vw % 4) == 2)), 4f, select((((r_a % 6) == 4) && ((vw % 4) == 1)), -2f, select((((r_a % 6) == 4) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 3) && ((vw % 4) == 3)), 0.125f, select((((r_a % 6) == 3) && ((vw % 4) == 2)), 0.25f, select((((r_a % 6) == 3) && ((vw % 4) == 1)), 0.5f, select((((r_a % 6) == 3) && ((vw % 4) == 
0)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 3)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 2)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 1)), 1f, select((((r_a % 6) == 2) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 1) && ((vw % 4) == 3)), -1f, select((((r_a % 6) == 1) && ((vw % 4) == 2)), 1f, select((((r_a % 6) == 1) && ((vw % 4) == 1)), -1f, select((((r_a % 6) == 1) && ((vw % 4) == 0)), 1f, select((((r_a % 6) == 0) && ((vw % 4) == 3)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 2)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 1)), 0f, select((((r_a % 6) == 0) && ((vw % 4) == 0)), 1f, 0f))))))))))))))))))))))))))
}
}
}
}
}
}
}
for (h.inner, 0, 4) {
for (w.inner, 0, 4) {
output[(((((((blockIdx.x*128) + threadIdx.x)/16)*256) + (h.inner*64)) + ((((blockIdx.x*128) + threadIdx.x) % 16)*4)) + w.inner)] = inverse[(((((((((((blockIdx.x*128) + threadIdx.x)/16384)*4096) + ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*4) + h.inner)/4)*256)) + (((((((blockIdx.x*128) + threadIdx.x) % 16)*4) + w.inner)/4)*16)) + (((((((blockIdx.x*128) + threadIdx.x) % 16384)/256) - (((blockIdx.x*8) + (threadIdx.x/16))/16))*((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256)))*16)) + ((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*4) + h.inner) % 4)*4)) + ((((((blockIdx.x*128) + threadIdx.x) % 16)*4) + w.inner) % 4)) - ((threadIdx.x % 16)*16)) - ((((blockIdx.x*8) + (threadIdx.x/16)) % 16)*256))]
}
}
}
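Please note the following lines in the output block: the size of the `inverse` allocation is a non-constant expression that depends on threadIdx.x, yet the allocate is emitted before the threadIdx.x thread_extent attribute. (Since 128 is a multiple of 16, ((blockIdx.x*128) + threadIdx.x) % 16 equals threadIdx.x % 16, so the first factor should simplify to 1 and the allocation to float32 * 16.)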
allocate inverse[float32 * ((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*16) + (threadIdx.x % 16)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))*4)*4)]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 128
Before this PR (https://github.com/dmlc/tvm/pull/3368), it looked like this:
// attr [U] storage_scope = "global"
allocate U[float32 * 294912]
// attr [V] storage_scope = "global"
allocate V[float32 * 1179648]
produce U {
parallel (k, 0, 64) {
for (c, 0, 128) {
unrolled (eps, 0, 6) {
unrolled (nu, 0, 6) {
U[((((((eps*6) + nu)*64) + k)*128) + c)] = 0.000000f
unrolled (r_kh, 0, 3) {
unrolled (r_kw, 0, 3) {
U[((((((eps*6) + nu)*64) + k)*128) + c)] = (U[((((((eps*6) + nu)*64) + k)*128) + c)] + ((W[((((((k*128) + c)*3) + r_kh)*3) + r_kw)]*select((((eps % 6) == 5) && ((r_kh % 3) == 2)), 1.000000f, select((((eps % 6) == 5) && ((r_kh % 3) == 1)), 0.000000f, select((((eps % 6) == 5) && ((r_kh % 3) == 0)), 0.000000f, select((((eps % 6) == 4) && ((r_kh % 3) == 2)), 0.166667f, select((((eps % 6) == 4) && ((r_kh % 3) == 1)), -0.083333f, select((((eps % 6) == 4) && ((r_kh % 3) == 0)), 0.041667f, select((((eps % 6) == 3) && ((r_kh % 3) == 2)), 0.166667f, select((((eps % 6) == 3) && ((r_kh % 3) == 1)), 0.083333f, select((((eps % 6) == 3) && ((r_kh % 3) == 0)), 0.041667f, select((((eps % 6) == 2) && ((r_kh % 3) == 2)), -0.166667f, select((((eps % 6) == 2) && ((r_kh % 3) == 1)), 0.166667f, select((((eps % 6) == 2) && ((r_kh % 3) == 0)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 2)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 1)), -0.166667f, select((((eps % 6) == 1) && ((r_kh % 3) == 0)), -0.166667f, select((((eps % 6) == 0) && ((r_kh % 3) == 2)), 0.000000f, select((((eps % 6) == 0) && ((r_kh % 3) == 1)), 0.000000f, select((((eps % 6) == 0) && ((r_kh % 3) == 0)), 0.250000f, 0.000000f)))))))))))))))))))*select((((nu % 6) == 5) && ((r_kw % 3) == 2)), 1.000000f, select((((nu % 6) == 5) && ((r_kw % 3) == 1)), 0.000000f, select((((nu % 6) == 5) && ((r_kw % 3) == 0)), 0.000000f, select((((nu % 6) == 4) && ((r_kw % 3) == 2)), 0.166667f, select((((nu % 6) == 4) && ((r_kw % 3) == 1)), -0.083333f, select((((nu % 6) == 4) && ((r_kw % 3) == 0)), 0.041667f, select((((nu % 6) == 3) && ((r_kw % 3) == 2)), 0.166667f, select((((nu % 6) == 3) && ((r_kw % 3) == 1)), 0.083333f, select((((nu % 6) == 3) && ((r_kw % 3) == 0)), 0.041667f, select((((nu % 6) == 2) && ((r_kw % 3) == 2)), -0.166667f, select((((nu % 6) == 2) && ((r_kw % 3) == 1)), 0.166667f, select((((nu % 6) == 2) && ((r_kw % 3) == 0)), -0.166667f, select((((nu % 6) == 1) && ((r_kw % 3) == 2)), -0.166667f, select((((nu 
% 6) == 1) && ((r_kw % 3) == 1)), -0.166667f, select((((nu % 6) == 1) && ((r_kw % 3) == 0)), -0.166667f, select((((nu % 6) == 0) && ((r_kw % 3) == 2)), 0.000000f, select((((nu % 6) == 0) && ((r_kw % 3) == 1)), 0.000000f, select((((nu % 6) == 0) && ((r_kw % 3) == 0)), 0.250000f, 0.000000f))))))))))))))))))))
}
}
}
}
}
}
}
produce V {
parallel (b, 0, 256) {
// attr [d.global] storage_scope = "global"
allocate d.global[float32 * 36]
for (c, 0, 128) {
produce d.global {
for (ax2, 0, 6) {
for (ax3, 0, 6) {
d.global[((ax2*6) + ax3)] = tvm_if_then_else((((((((((b + 0)/16) % 16)*4) + ax2) >= 1) && ((((((b + 0)/16) % 16)*4) + ax2) < 65)) && (((((b + 0) % 16)*4) + ax3) >= 1)) && (((((b + 0) % 16)*4) + ax3) < 65)), A[((((((((b/256)*128) + c)*64) + ((((b/16) % 16)*4) + ax2))*64) + (((b % 16)*4) + ax3)) + -65)], 0.000000f)
}
}
}
unrolled (eps, 0, 6) {
unrolled (nu, 0, 6) {
V[((((((eps*6) + nu)*256) + b)*128) + c)] = 0.000000f
unrolled (r_eps, 0, 6) {
unrolled (r_nu, 0, 6) {
V[((((((eps*6) + nu)*256) + b)*128) + c)] = (V[((((((eps*6) + nu)*256) + b)*128) + c)] + ((d.global[((r_eps*6) + r_nu)]*select((((r_eps % 6) == 5) && ((eps % 6) == 5)), 1.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 4)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 3)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 2)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 1)), 0.000000f, select((((r_eps % 6) == 5) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 5)), 0.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 4)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 3)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 2)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 1)), 1.000000f, select((((r_eps % 6) == 4) && ((eps % 6) == 0)), 1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 5)), -5.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 4)), -2.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 3)), 2.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 2)), -1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 1)), 1.000000f, select((((r_eps % 6) == 3) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 5)), 0.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 4)), -1.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 3)), -1.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 2)), -4.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 1)), -4.000000f, select((((r_eps % 6) == 2) && ((eps % 6) == 0)), -5.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 5)), 4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 4)), 2.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 3)), -2.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 2)), 4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 1)), -4.000000f, select((((r_eps % 6) == 1) && ((eps % 6) == 0)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 5)), 0.000000f, 
select((((r_eps % 6) == 0) && ((eps % 6) == 4)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 3)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 2)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 1)), 0.000000f, select((((r_eps % 6) == 0) && ((eps % 6) == 0)), 4.000000f, 0.000000f)))))))))))))))))))))))))))))))))))))*select((((r_nu % 6) == 5) && ((nu % 6) == 5)), 1.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 4)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 3)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 2)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 1)), 0.000000f, select((((r_nu % 6) == 5) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 4)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 3)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 2)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 1)), 1.000000f, select((((r_nu % 6) == 4) && ((nu % 6) == 0)), 1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 5)), -5.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 4)), -2.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 3)), 2.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 2)), -1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 1)), 1.000000f, select((((r_nu % 6) == 3) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 4)), -1.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 3)), -1.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 2)), -4.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 1)), -4.000000f, select((((r_nu % 6) == 2) && ((nu % 6) == 0)), -5.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 5)), 4.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 4)), 2.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 3)), -2.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 2)), 4.000000f, select((((r_nu % 6) 
== 1) && ((nu % 6) == 1)), -4.000000f, select((((r_nu % 6) == 1) && ((nu % 6) == 0)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 5)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 4)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 3)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 2)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 1)), 0.000000f, select((((r_nu % 6) == 0) && ((nu % 6) == 0)), 4.000000f, 0.000000f))))))))))))))))))))))))))))))))))))))
}
}
}
}
}
}
}
produce output {
parallel (k.outer, 0, 64) {
// attr [M] storage_scope = "global"
allocate M[float32 * 9216]
// attr [M.global] storage_scope = "global"
allocate M.global[float32 * 36]
// attr [Y] storage_scope = "global"
allocate Y[float32 * 16]
produce M {
for (eps, 0, 6) {
for (nu, 0, 6) {
for (b.outer, 0, 256) {
M[((((eps*6) + nu)*256) + b.outer)] = 0.000000f
for (c.outer, 0, 128) {
M[((((eps*6) + nu)*256) + b.outer)] = (M[((((eps*6) + nu)*256) + b.outer)] + (U[((((((eps*6) + nu)*64) + k.outer)*128) + c.outer)]*V[((((((eps*6) + nu)*256) + b.outer)*128) + c.outer)]))
}
}
}
}
}
for (h.outer, 0, 16) {
for (w.outer, 0, 16) {
produce M.global {
for (ax0, 0, 6) {
for (ax1, 0, 6) {
M.global[((ax0*6) + ax1)] = M[((((ax0*6) + ax1)*256) + ((h.outer*16) + w.outer))]
}
}
}
produce Y {
unrolled (vh, 0, 4) {
unrolled (vw, 0, 4) {
Y[((vh*4) + vw)] = 0.000000f
unrolled (r_eps, 0, 6) {
unrolled (r_nu, 0, 6) {
Y[((vh*4) + vw)] = (Y[((vh*4) + vw)] + ((M.global[((r_eps*6) + r_nu)]*select((((r_eps % 6) == 5) && ((vh % 4) == 3)), 1.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 2)), 0.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 1)), 0.000000f, select((((r_eps % 6) == 5) && ((vh % 4) == 0)), 0.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 3)), -8.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 2)), 4.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 1)), -2.000000f, select((((r_eps % 6) == 4) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 3)), 8.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 2)), 4.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 1)), 2.000000f, select((((r_eps % 6) == 3) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 3)), -1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 2)), 1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 1)), -1.000000f, select((((r_eps % 6) == 2) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 3)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 2)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 1)), 1.000000f, select((((r_eps % 6) == 1) && ((vh % 4) == 0)), 1.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 3)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 2)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 1)), 0.000000f, select((((r_eps % 6) == 0) && ((vh % 4) == 0)), 1.000000f, 0.000000f)))))))))))))))))))))))))*select((((r_nu % 6) == 5) && ((vw % 4) == 3)), 1.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 2)), 0.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 1)), 0.000000f, select((((r_nu % 6) == 5) && ((vw % 4) == 0)), 0.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 3)), -8.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 2)), 4.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 1)), -2.000000f, select((((r_nu % 6) == 4) && ((vw % 4) == 0)), 1.000000f, 
select((((r_nu % 6) == 3) && ((vw % 4) == 3)), 8.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 2)), 4.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 1)), 2.000000f, select((((r_nu % 6) == 3) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 3)), -1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 2)), 1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 1)), -1.000000f, select((((r_nu % 6) == 2) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 3)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 2)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 1)), 1.000000f, select((((r_nu % 6) == 1) && ((vw % 4) == 0)), 1.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 3)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 2)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 1)), 0.000000f, select((((r_nu % 6) == 0) && ((vw % 4) == 0)), 1.000000f, 0.000000f))))))))))))))))))))))))))
}
}
}
}
}
for (h.inner, 0, 4) {
for (w.inner, 0, 4) {
output[((((k.outer*64) + (h.inner + (h.outer*4)))*64) + (w.inner + (w.outer*4)))] = Y[(((((((((h.inner/4) + h.outer)*16) + ((w.inner/4) + w.outer)) - ((h.outer*16) + w.outer))*4) + (h.inner % 4))*4) + (w.inner % 4))]
}
}
}
}
}
}
Something probably went wrong in StorageRewrite after simplification.