How to control data layout of intermediate values

Hey, I am working on a simple box blur, mainly to write an introductory blog post on TVM. It has four stages: padding, x-blur, y-blur, and a final cast back to the uint8 datatype.

The y-blur accesses the x-blur result at 3 different y increments, so I would like to lay out x_blur such that y is the most rapidly changing dimension. I have been successful at this by changing the actual algorithm definition, but I would like to do it during scheduling instead.

I am trying to use the binds argument of tvm.build with decl_buffer to control the strides of x_blur, but binding x_blur to any decl_buffer object causes TVM to think that I want to pass x_blur in as a function argument, which is definitely not what I want:

chrisn@chrisn-lt-2:~/dev/tvm-clj/python/questions$ cat bind_buffer.py
import tvm


def print_schedule(sched, arglist):
    print(tvm.lower(sched, arglist, simple_mode=True))


rows = tvm.var("rows")
cols = tvm.var("cols")
chans = tvm.var("chans")

input_vec = tvm.placeholder((rows,cols,chans), dtype="float32", name="input")
clamp = lambda v, v_min, v_max: tvm.max( tvm.min(v, v_max), v_min )
## clamp to edge padding
padded = tvm.compute((rows+2,cols+2,chans)
                     , lambda y, x, c: input_vec[clamp(y-1, 0, rows-1)
                                                 , clamp(x-1, 0, cols-1)
                                                 , c].astype("uint16")
                     , name="padded")



x_blur = tvm.compute((rows+2, cols, chans)
                     , lambda y, x, c: (padded[y,x,c] +
                                        padded[y,x+1,c] +
                                        padded[y,x+2,c]) / 3
                     , name="x_blur")

y_blur = tvm.compute((rows, cols, chans)
                     , lambda y, x, c: (x_blur[y,x,c] +
                                        x_blur[y+1,x,c] +
                                        x_blur[y+2,x,c]) / 3
                     , name="y_blur")

box_blur = tvm.compute((rows,cols,chans)
                       , lambda y, x, c: y_blur[y,x,c].astype("uint8")
                       , name="box_blur")

arglist = [input_vec, box_blur]

schedule = tvm.create_schedule(box_blur.op)
schedule[padded.op].compute_inline()
schedule[y_blur].compute_inline()
schedule[x_blur].compute_at(schedule[box_blur], box_blur.op.axis[1])
print_schedule(schedule, arglist)

x_blur_y_stride = 1
x_blur_c_stride = rows + 2
x_blur_x_stride = x_blur_c_stride * 3

fun = tvm.build(schedule, arglist, "llvm", name="box_blur"
                , binds={x_blur: tvm.decl_buffer(x_blur.shape
                                                 , name="x_blur"
                                                 , scope="local"
                                                 , dtype=x_blur.dtype
                                                 , strides=[x_blur_y_stride,
                                                            x_blur_x_stride,
                                                            x_blur_c_stride])})
chrisn@chrisn-lt-2:~/dev/tvm-clj/python/questions$ python3 bind_buffer.py
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
// attr [x_blur] storage_scope = "global"
allocate x_blur[int32 * 3 * 1 * chans]
produce box_blur {
  for (y, 0, rows) {
    for (x, 0, cols) {
      produce x_blur {
        for (y, 0, 3) {
          for (c, 0, chans) {
            x_blur[((y*chans) + c)] = (int32(((uint16(input[(((max((min(x, cols) + -1), 0) + (max((min((y + y), rows) + -1), 0)*cols))*chans) + c)]) + uint16(input[(((max(min(x, (cols + -1)), 0) + (max((min((y + y), rows) + -1), 0)*cols))*chans) + c)])) + uint16(input[(((max(min((x + 1), (cols + -1)), 0) + (max((min((y + y), rows) + -1), 0)*cols))*chans) + c)])))/3)
          }
        }
      }
      for (c, 0, chans) {
        box_blur[((((y*cols) + x)*chans) + c)] = uint8((((x_blur[c] + x_blur[(chans + c)]) + x_blur[((chans*2) + c)])/3))
      }
    }
  }
}

[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Cast
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
[14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/arithmetic/int_set.cc:514: cannot evaluate set type Load
Traceback (most recent call last):
  File "bind_buffer.py", line 58, in <module>
    x_blur_c_stride])})
  File "/home/chrisn/.local/lib/python3.6/site-packages/tvm-0.5.dev0-py3.6-linux-x86_64.egg/tvm/build_module.py", line 445, in build
    binds=binds)
  File "/home/chrisn/.local/lib/python3.6/site-packages/tvm-0.5.dev0-py3.6-linux-x86_64.egg/tvm/build_module.py", line 380, in lower
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
  File "/home/chrisn/.local/lib/python3.6/site-packages/tvm-0.5.dev0-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/home/chrisn/.local/lib/python3.6/site-packages/tvm-0.5.dev0-py3.6-linux-x86_64.egg/tvm/_ffi/base.py", line 66, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [14:14:18] /home/chrisn/dev/tvm-clj/tvm/src/pass/make_api.cc:169: Not all Vars are passed in api_args:  'x_blur'  does not appeared in api_args

Is there a way, at scheduling time, to dictate the layout of intermediate buffers (ones that are allocated/deallocated by TVM during the course of its execution)?

I think this is a useful trick that is worth a tutorial.

The example below shows how to transpose the layout of an intermediate buffer:

import tvm

n = 10
m = 20

A = tvm.placeholder((n, m), name='A')
B = tvm.compute((n, m), lambda i, j: A[i][j], name='B')
C = tvm.compute((n, m), lambda i, j: B[i][j], name='C')

s = tvm.create_schedule([C.op])
print(tvm.lower(s, [A, C], simple_mode=True))

print("======================================\n")


i, j = s[B].op.axis
s[B].reorder(j, i)     # transpose
BB = s.cache_write(B, 'global')
s[B].compute_inline()

print(tvm.lower(s, [A, C], simple_mode=True))

Output:

// attr [B] storage_scope = "global"
allocate B[float32 * 10 * 20]
produce B {
  for (i, 0, 10) {
    for (j, 0, 20) {
      B[((i*20) + j)] = A[((i*20) + j)]
    }
  }
}
produce C {
  for (i, 0, 10) {
    for (j, 0, 20) {
      C[((i*20) + j)] = B[((i*20) + j)]
    }
  }
}

======================================

// attr [B.global] storage_scope = "global"
allocate B.global[float32 * 20 * 10]
produce B.global {                     // B.global is transposed
  for (j.c, 0, 20) {
    for (i.c, 0, 10) {
      B.global[((j.c*10) + i.c)] = A[(j.c + (i.c*20))]
    }
  }
}
produce C {
  for (i, 0, 10) {
    for (j, 0, 20) {
      C[((i*20) + j)] = B.global[(i + (j*10))]
    }
  }
}

That is one step further. There is another rub, however.

The proposed solution links the computation order of B to its storage. In my case, I would like x_blur to be computed in [y,x,c] order but stored in [x,c,y] order. Then y_blur will access x_blur using normal [y,x,c] coordinates.

So, without changing the computation order of the stage, can I specify the layout of the cache_write object? Basically, accessing the source image out of order is as expensive as accessing x_blur out of order. I would like to access the source image in order, then write to x_blur out of order, letting the memory subsystem do whatever it does; nothing depends on those writes for a long time.

Then x_blur is read out of order again to produce y_blur. In my case, in-order means [y,x,c] and out-of-order means [x,c,y]. Does this make sense? It really is just changing the storage layout of x_blur and nothing else about the algorithm, including the iteration order across the source image.

The simplest solution so far is just to make the out-of-order-ness of x_blur part of the algorithm and then reorder the x_blur axes during scheduling. That, however, bakes the optimization into the algorithm itself.
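Concretely, that baked-in version looks roughly like this (a sketch reusing the padded, rows, cols, and chans definitions from my script above; the x_blur axes are declared in [x,c,y] order and the [y,x,c] compute order is restored by a reorder during scheduling):

## Bake the [x, c, y] storage order into the algorithm itself.
x_blur = tvm.compute((cols, chans, rows + 2)
                     , lambda x, c, y: (padded[y, x, c] +
                                        padded[y, x + 1, c] +
                                        padded[y, x + 2, c]) / 3
                     , name="x_blur")

## Every consumer now has to address x_blur with [x, c, y] indexing.
y_blur = tvm.compute((rows, cols, chans)
                     , lambda y, x, c: (x_blur[x, c, y] +
                                        x_blur[x, c, y + 1] +
                                        x_blur[x, c, y + 2]) / 3
                     , name="y_blur")

box_blur = tvm.compute((rows, cols, chans)
                       , lambda y, x, c: y_blur[y, x, c].astype("uint8")
                       , name="box_blur")

## During scheduling, restore the [y, x, c] iteration order for x_blur.
schedule = tvm.create_schedule(box_blur.op)
x, c, y = schedule[x_blur].op.axis
schedule[x_blur].reorder(y, x, c)

The downside is exactly what I said: every consumer of x_blur has to use the transposed indexing, so the layout choice leaks into the algorithm definition.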

You can schedule the object created by cache_write too.
Use cache_write to change the layout, and use reorder and compute_at to change the compute order.
Can this handle your case? Otherwise I have no solution.

import tvm

n = 10
m = 20

A = tvm.placeholder((n, m), name='A')
B = tvm.compute((n, m), lambda i, j: A[i][j], name='B')
C = tvm.compute((n, m), lambda i, j: B[i][j], name='C')

s = tvm.create_schedule([C.op])
print(tvm.lower(s, [A, C], simple_mode=True))

print("======================================\n")


i, j = s[B].op.axis
s[B].reorder(j, i)
BB = s.cache_write(B, 'global')
s[B].compute_inline()

j, i = s[BB].op.axis
s[BB].reorder(i, j)

print(tvm.lower(s, [A, C], simple_mode=True))

Output:

// attr [B] storage_scope = "global"
allocate B[float32 * 10 * 20]
produce B {
  for (i, 0, 10) {
    for (j, 0, 20) {
      B[((i*20) + j)] = A[((i*20) + j)]
    }
  }
}
produce C {
  for (i, 0, 10) {
    for (j, 0, 20) {
      C[((i*20) + j)] = B[((i*20) + j)]
    }
  }
}

======================================

// attr [B.global] storage_scope = "global"
allocate B.global[float32 * 20 * 10]
produce B.global {
  for (i.c, 0, 10) {
    for (j.c, 0, 20) {
      B.global[(i.c + (j.c*10))] = A[((i.c*20) + j.c)]
    }
  }
}
produce C {
  for (i, 0, 10) {
    for (j, 0, 20) {
      C[((i*20) + j)] = B.global[(i + (j*10))]
    }
  }
}

It certainly might. I will try it (and any version of it I can think of) and get back to you.

It worked exactly as advertised! Thank you. This compiler is pretty fun to work with.
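For the record, applied to my original box blur the suggestion comes out roughly like this (a sketch following the cache_write pattern above, reusing the tensors and arglist from my first script, not my exact final code):

schedule = tvm.create_schedule(box_blur.op)
schedule[padded.op].compute_inline()
schedule[y_blur].compute_inline()

## Reorder x_blur's axes so the cache created below inherits [x, c, y]
## storage.
y, x, c = schedule[x_blur].op.axis
schedule[x_blur].reorder(x, c, y)

## Snapshot that layout into a cache, then inline the reordered stage away.
x_blur_cache = schedule.cache_write(x_blur, "global")
schedule[x_blur].compute_inline()

## The cache's axes come back in [x, c, y] order; restore the [y, x, c]
## compute order without touching the storage layout.
x, c, y = schedule[x_blur_cache].op.axis
schedule[x_blur_cache].reorder(y, x, c)

print(tvm.lower(schedule, arglist, simple_mode=True))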