I'm trying to understand how TVM's buffer layout optimizations work. I'm pattern matching a matrix multiply operation within a 2D convolution using tensorize, with the following parameters:
It is a 3x3 convolution (the input is explicitly pre-padded, so pad_height = pad_width = 0):
minibatch = 1, output_height = output_width = 28, input_height = input_width = 30, input_channel = output_channel = 128, kernel_height = kernel_width = 3
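As a quick sanity check on how these parameters map to the shapes in the code below (my own arithmetic, nothing TVM-specific):

```python
# Valid conv on the pre-padded input: (30 - 3) // 1 + 1 = 28 output rows/cols.
assert (30 - 3) // 1 + 1 == 28   # output_height / output_width
# The 128 channels are blocked as 2 outer blocks x 64 vector lanes,
# which is why the tensors below are 5D/6D.
assert 2 * 64 == 128
```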
My problem is with the input binding when tensorize is called. I declared the input buffer with explicit strides, but binding throws an error saying there is an unmet assertion on the stride of the leading dimension of the input buffer. I believe I'm passing the stride correctly, so I suspect some buffer indexing/layout optimization under the hood (such as compaction) is changing the layout.
The problem is here. I specified the stride of the input buffer's leading dimension as 30*30*64, as follows:
```python
yy_ptr = tvm.decl_buffer(B.shape, B.dtype,
                         name="some", offset_factor=1,
                         strides=[30*30*64, 30*64, 64, 1],
                         data_alignment=64)
```
but when binding is done, it complains that some.strides[0] has an unmet assertion.
This goes away when I declare it as follows:
```python
yy_ptr = tvm.decl_buffer(B.shape, B.dtype,
                         name="some", offset_factor=1,
                         strides=[3*30*64, 30*64, 64, 1],
                         data_alignment=64)
```
which seems to indicate some layout optimization under the hood: the input height dimension is seemingly shrunk to the filter height.
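Spelling out the arithmetic, the stride that binding accepts matches a buffer compacted down to the intrinsic's (1, 3, 30, 64) input window rather than the full 30-row input; this is my reading of the error, not something I've confirmed:

```python
# Stride[0] I declared, from the full (1, 2, 30, 30, 64) input layout:
declared_stride0 = 30 * 30 * 64   # 57600 elements per outer-channel block
# Stride[0] the binding accepts, i.e. a compacted (1, 3, 30, 64) window
# holding only the kernel_height = 3 input rows the intrinsic touches:
accepted_stride0 = 3 * 30 * 64    # 5760 elements
```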
The full code is below:
```python
import tvm

def test_conv():
    # Input blocked as (N, C_outer, H, W, C_inner) with 64-wide channel blocks.
    A1 = tvm.placeholder((1, 2, 30, 30, 64), name='input')
    W1 = tvm.placeholder((2, 2, 3, 3, 64, 64), name='weight')
    rco1 = tvm.reduce_axis((0, 2), name='rco1')
    ry1 = tvm.reduce_axis((0, 3), name='ry1')
    rx1 = tvm.reduce_axis((0, 3), name='rx1')
    rci1 = tvm.reduce_axis((0, 64), name='rci1')
    stride_height = 1
    stride_width = 1
    B1 = tvm.compute(
        (1, 2, 28, 28, 64),
        lambda nn, ff, yy, xx, vlen1: tvm.sum(
            W1[ff, rco1, ry1, rx1, rci1, vlen1] *
            A1[nn, rco1, ry1 + stride_height * yy, rx1 + stride_width * xx, rci1],
            axis=[rco1, ry1, rx1, rci1]),
        name='output')
    s = tvm.create_schedule(B1.op)
    n, ko, h, w, ki = s[B1].op.axis
    rco, ry, rx, rci = s[B1].op.reduce_axis
    w_factor_inner = 28
    tile_c = 1
    tile_h = 2
    wo, wi = s[B1].split(w, w_factor_inner)
    ho, hi = s[B1].split(h, tile_h)
    rco_o, rco_i = s[B1].split(rco, tile_c)
    s[B1].reorder(n, ko, rco_o, wo, ho, hi, rco_i, ry, rx, wi, ki, rci)
    #print(tvm.lower(s, [W1, A1, B1], simple_mode=True))

    def intrin():
        A = tvm.placeholder((1, 3, 3, 64, 64), name='w')
        B = tvm.placeholder((1, 3, 30, 64), name='b')
        k = tvm.reduce_axis((0, 64), name='k')
        k_outer = tvm.reduce_axis((0, 1), name='k_outer')
        ry = tvm.reduce_axis((0, 3), name='ry')
        rx = tvm.reduce_axis((0, 3), name='rx')
        stride_width = 1
        C = tvm.compute(
            (28, 64),
            lambda m, n: tvm.sum(
                A[k_outer, ry, rx, k, n] * B[k_outer, ry, rx + m * stride_width, k],
                axis=[k_outer, ry, rx, k]),
            name='out')
        s1 = tvm.create_schedule(C.op)
        w, ofm = s1[C].op.axis
        kco, ky, kx, kci = s1[C].op.reduce_axis
        s1[C].reorder(kco, ky, kx, w, ofm, kci)
        xx_ptr = tvm.decl_buffer(A.shape, A.dtype,
                                 name="W", offset_factor=1,
                                 data_alignment=64)
        # This is the declaration that trips the unmet assertion on
        # some.strides[0]; changing 30*30*64 to 3*30*64 makes it bind.
        yy_ptr = tvm.decl_buffer(B.shape, B.dtype,
                                 name="some", offset_factor=1,
                                 strides=[30*30*64, 30*64, 64, 1],
                                 data_alignment=64)
        zz_ptr = tvm.decl_buffer(C.shape, C.dtype,
                                 name="OUT", offset_factor=1,
                                 data_alignment=64)

        def intrin_func(ins, outs):
            body = tvm.call_extern("int32", "dummy",
                                   ins[0].access_ptr("r"),
                                   ins[1].access_ptr("r"),
                                   outs[0].access_ptr("w"))
            return body, None, body

        with tvm.build_config(data_alignment=64):
            return tvm.decl_tensor_intrin(C.op, intrin_func, name="GEMM",
                                          binds={A: xx_ptr,
                                                 B: yy_ptr,
                                                 C: zz_ptr})

    tensorized = intrin()
    s[B1].tensorize(rco_i, tensorized)
    print(tvm.lower(s, [W1, A1, B1], simple_mode=True))

if __name__ == "__main__":
    test_conv()
```
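For completeness, one workaround I'm considering is the pattern from the TVM tensorize tutorial: declare the outer strides as symbolic variables so the binding matches whatever layout compaction produces instead of asserting on a hard-coded leading stride. A minimal, untested sketch of that variant for the input buffer above:

```python
import tvm

B = tvm.placeholder((1, 3, 30, 64), name='b')
# s1..s3 are symbolic strides TVM resolves at bind time, so the generated
# assertion no longer pins strides[0] to a specific value.
yy_ptr = tvm.decl_buffer(B.shape, B.dtype,
                         name="some", offset_factor=1,
                         strides=[tvm.var("s1"), tvm.var("s2"), tvm.var("s3"), 1],
                         data_alignment=64)
```

But I'd still like to understand which pass rewrites the layout and why the compacted stride is what it is.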