Incorrect bound inference when the schedule contains tvm_if_then_else

Hello, I found that when the input network has multiple convolution layers with pad operations, the memory allocated for the internal tensors is larger than necessary, and the expansion accumulates as it propagates from the output stage back to the input stage.

As an example, assume a network has 3 convolution layers back to back. The Python script is as follows:

import tvm
import topi

x = tvm.placeholder((1, 64, 56, 56))
w1 = tvm.placeholder((64, 64, 1, 1))
w2 = tvm.placeholder((64, 64, 3, 3))
w3 = tvm.placeholder((64, 64, 3, 3))
y1 = topi.nn.conv2d(x, w1, strides=1, padding=0, dilation=1)
y2 = topi.nn.conv2d(y1, w2, strides=1, padding=1, dilation=1)
y3 = topi.nn.conv2d(y2, w3, strides=1, padding=1, dilation=1)

s = tvm.create_schedule(y3.op)
stmt = tvm.lower(s, [x, w1, w2, w3, y1, y2, y3], simple_mode=True)

print(stmt)

The output shows that a “pad_temp” buffer of size [1*64*60*60] (= 230400 elements) is allocated and reused for the 3 layers. This wastes memory: the ideal size of “pad_temp” is [1*64*58*58] (= 215296 elements).

// attr [pad_temp] storage_scope = "global"
allocate pad_temp[float32 * 230400]
produce pad_temp {
  for (i1, 0, 64) {
    for (i2, 0, 60) {
      for (i3, 0, 60) {
        if (likely((2 <= i2))) {
          if (likely((i2 < 58))) {
            if (likely((2 <= i3))) {
              if (likely((i3 < 58))) {
                pad_temp[((((i1*60) + i2)*60) + i3)] = placeholder[(((((i1*56) + i2)*56) + i3) + -114)]
              }
            }
          }
        }
      }
    }
  }
}
produce compute {
  for (ff, 0, 64) {
    for (yy, 0, 60) {
      for (xx, 0, 60) {
        compute[(((((ff*56) + yy)*56) + xx) + -114)] = 0.000000f
        for (rc, 0, 64) {
          if (likely((2 <= yy))) {
            if (likely((yy < 58))) {
              if (likely((2 <= xx))) {
                if (likely((xx < 58))) {
                  compute[(((((ff*56) + yy)*56) + xx) + -114)] = (compute[(((((ff*56) + yy)*56) + xx) + -114)] + (pad_temp[(((yy*60) + xx) + (rc*3600))]*placeholder[((ff*64) + rc)]))
                }
              }
            }
          }
        }
      }
    }
  }
}
produce pad_temp {
  for (i1, 0, 64) {
    for (i2, 0, 60) {
      for (i3, 0, 60) {
        if (likely((1 <= i2))) {
          if (likely((i2 < 59))) {
            if (likely((1 <= i3))) {
              if (likely((i3 < 59))) {
                pad_temp[((((i1*60) + i2)*60) + i3)] = tvm_if_then_else(((((2 <= i2) && (i2 < 58)) && (2 <= i3)) && (i3 < 58)), compute[(((((i1*56) + i2)*56) + i3) + -114)], 0.000000f)
              }
            }
          }
        }
      }
    }
  }
}
produce compute {
  for (ff, 0, 64) {
    for (yy, 0, 58) {
      for (xx, 0, 58) {
        compute[(((((ff*56) + yy)*56) + xx) + -57)] = 0.000000f
        for (rc, 0, 64) {
          for (ry, 0, 3) {
            for (rx, 0, 3) {
              if (likely((1 <= yy))) {
                if (likely((yy < 57))) {
                  if (likely((1 <= xx))) {
                    if (likely((xx < 57))) {
                      compute[(((((ff*56) + yy)*56) + xx) + -57)] = (compute[(((((ff*56) + yy)*56) + xx) + -57)] + (pad_temp[(((((yy*60) + xx) + (rc*3600)) + (ry*60)) + rx)]*placeholder[((((((ff*64) + rc)*3) + ry)*3) + rx)]))
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
produce pad_temp {
  for (i1, 0, 64) {
    for (i2, 0, 58) {
      for (i3, 0, 58) {
        pad_temp[((((i1*58) + i2)*58) + i3)] = tvm_if_then_else(((((1 <= i2) && (i2 < 57)) && (1 <= i3)) && (i3 < 57)), compute[(((((i1*56) + i2)*56) + i3) + -57)], 0.000000f)
      }
    }
  }
}
produce compute {
  for (ff, 0, 64) {
    for (yy, 0, 56) {
      for (xx, 0, 56) {
        compute[((((ff*56) + yy)*56) + xx)] = 0.000000f
        for (rc, 0, 64) {
          for (ry, 0, 3) {
            for (rx, 0, 3) {
              compute[((((ff*56) + yy)*56) + xx)] = (compute[((((ff*56) + yy)*56) + xx)] + (pad_temp[(((((yy*58) + xx) + (rc*3364)) + (ry*58)) + rx)]*placeholder[((((((ff*64) + rc)*3) + ry)*3) + rx)]))
            }
          }
        }
      }
    }
  }
}
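
To put a number on the wastage in the dump above, here is a small arithmetic sketch in plain Python (no TVM needed). The helper name `buffer_elems` is made up for illustration; the shapes are the ones quoted in this post.

```python
def buffer_elems(n, c, h, w):
    """Number of elements in an NCHW buffer."""
    return n * c * h * w

allocated = buffer_elems(1, 64, 60, 60)  # what the lowered IR allocates for pad_temp
ideal = buffer_elems(1, 64, 58, 58)      # 56 + 2*1 padding on each spatial axis
wasted = allocated - ideal

print("allocated:", allocated)
print("ideal:    ", ideal)
print("wasted:   ", wasted, "elements")
print("overhead:  %.1f%%" % (100.0 * wasted / ideal))
```

The overhead for this 3-layer example is modest, but it grows with every additional padded layer in the chain.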

On further inspection, we found the cause in the “InferBound” pass: when the function “PropBoundToInputs” (in bound.cc) infers the tensor domain, it uses only the domain of the iteration variables and does not consider the constraints in the tvm_if_then_else condition. The inferred bound is therefore larger than needed.

For example (pseudo TVM Halide IR):

for(ax0, 0, 58){
    a[ax0] = tvm_if_then_else((ax0 >= 1) && (ax0 < 57), b[ax0-1], 0)
}

b’s range would be inferred by TVM as [-1, 57) (hence rebased to [0, 58)), but b is really only accessed over [0, 56) once the if_then_else condition is taken into account.
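
The difference between the two inferred ranges can be reproduced with a few lines of interval arithmetic. This is only a sketch in plain Python, not TVM's actual implementation: the helpers `intersect` and `shift` are hypothetical, intervals are half-open `(lo, hi)` pairs, and the condition is assumed to mean 1 <= ax0 < 57.

```python
def intersect(a, b):
    """Intersect two half-open integer intervals given as (lo, hi)."""
    return (max(a[0], b[0]), min(a[1], b[1]))

def shift(interval, offset):
    """Range of `i + offset` when i ranges over `interval`."""
    return (interval[0] + offset, interval[1] + offset)

loop_dom = (0, 58)  # for (ax0, 0, 58)
cond_dom = (1, 57)  # assumed condition: 1 <= ax0 < 57

# Ignoring the condition: b[ax0 - 1] is evaluated over the whole loop domain.
naive = shift(loop_dom, -1)
# Honouring the condition: b is only read where the condition holds.
tight = shift(intersect(loop_dom, cond_dom), -1)

print("naive:", naive, "extent", naive[1] - naive[0])
print("tight:", tight, "extent", tight[1] - tight[0])
```

The tight interval starts at 0 and has extent 56, so no rebasing is needed and the producer of b does not have to be computed over extra elements.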

Since InferBound traverses the graph backwards, from the output stage to the input stage, the over-approximation compounds at every padded layer. The tensor allocated for the first stage (and possibly reused by the following ones) ends up larger than necessary, as the example shows.
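
The accumulation can be modelled with a short backward walk over the layers. This is a simplified sketch, not the InferBound code: the function `pad_extents` is hypothetical, and it assumes stride-1 convolutions whose kernel size is 2*pad + 1, as in the 3-layer script in this post.

```python
def pad_extents(out_size, paddings, respect_condition):
    """Walk the conv layers from output to input and return the spatial
    extent of each layer's padded input buffer, last layer first.

    paddings: per-layer padding, ordered from first layer to last.
    respect_condition: if True, model a bound inference that honours the
    padding condition; if False, the whole padded region is requested
    from the producer, as in the dump in this post.
    """
    extents = []
    required = out_size  # extent of the current layer's output that is needed
    for pad in reversed(paddings):
        padded = required + 2 * pad  # stride 1, kernel size 2*pad + 1
        extents.append(padded)
        # Only the interior of the padded buffer reads the producer's output.
        required = padded - 2 * pad if respect_condition else padded
    return extents

naive = pad_extents(56, [0, 1, 1], respect_condition=False)
tight = pad_extents(56, [0, 1, 1], respect_condition=True)
print("naive:", naive, "max", max(naive))
print("tight:", tight, "max", max(tight))
```

In the naive case the extents match the three pad_temp loops in the dump (58 for the last layer, 60 for the two before it), and the shared buffer has to be sized for the largest one.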

This becomes a more serious problem when the input network has a large number of layers with pad operations, because it wastes a considerable amount of memory. Yet this type of network is really common (e.g., resnet, vgg, etc).

Kindly let me know whether my understanding is correct. I wonder whether this problem has been encountered/observed before when applying TVM to DNNs, and if so, how it was solved?

Currently, we have made a local fix that uses the intersection of the if_then_else condition with the iteration variable domain to infer the bound of the tensor read inside it. But I would really like to hear about your solutions.

Thanks

In most cases the TVM stack runs a deep neural network through a two-level pass: the high-level pass cuts the computation graph into pieces, then the low-level tensor code generator generates code for each fused operator. So if you run it through Relay, the same problem would not occur.

This is really helpful for me; thanks for your question. I also tried to use lambda expressions to build a NN instead of a single layer, and hit the same bug as you. I think it should be possible to fix this bug rather than rely on Relay, so that we can get more flexible kernels. What is your view?

@wezuo
I think TVM should support writing multiple layers as a single kernel (e.g. to enable fusing two conv2d in the future).
You could contribute your fix.

Would you mind expanding on this?
How is the Relay definition of that same network not also going to lead to the same problems?

Yes, that’s what I think, will raise a PR soon.

Hello,

I also encountered problems like the ones you described in your post when building schedules with more than one conv2d with padding = 'same'.
I opened a new topic, [TVM Scheduling] Split factors missmatch to original extent, and would therefore like to know more about your fix:

  1. Is there an ETA available for this PR?
  2. Where exactly did you make the change? Is it inside the InferBounds routine, or is it an added ir_pass that makes changes based on the AST?
  3. I was also wondering whether the change would work for the case from my post, copied below for ease of reading:

Thanks

TVM doesn’t support cross-layer kernels. The Relay FuseOps pass cuts the computation graph into pieces with a single conv2d/dense/etc. each, so the boundary problem does not occur. But if you fuse multiple conv2d/dense/other main NN ops with your own FuseOps pass or OPPattern, then the TVM bound inference problem will occur in the tvm_if_then_else/padding/deconvolution cases. I wonder whether TVM will support fused kernels with multiple conv2d/dense/… in the future?

@wezuo Hi, have you already submitted the PR?

Thanks