Error when calling stage.bind()

I am trying to fuse a convolution and a max pooling into one kernel.

After stage.compute_at() is called, stage.bind() causes an error; without the compute_at() call it seems fine. Please check the attached code for details.

Traceback (most recent call last):
File “/home/xiaoquan.li/.PyCharmCE2018.1/config/scratches/scratch_1.py”, line 38, in
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))
File “/home/xiaoquan.li/tvm/python/tvm/build_module.py”, line 341, in lower
stmt = schedule.ScheduleOps(sch, bounds)
File “/home/xiaoquan.li/tvm/python/tvm/_ffi/_ctypes/function.py”, line 185, in call
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
File “/home/xiaoquan.li/tvm/python/tvm/_ffi/base.py”, line 68, in check_call
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [21:53:59] …/src/op/op_util.cc:137: Check failed: is_zero(dom->min)

Stack trace returned 10 entries:
[bt] (0) /home/xiaoquan.li/tvm/build/libtvm.so(dmlc::StackTraceabi:cxx11+0x54) [0x7f830154c3cf]
[bt] (1) /home/xiaoquan.li/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x2a) [0x7f830154c6b6]
[bt] (2) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::op::MakeLoopNest(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, unsigned long, bool, std::unordered_set<tvm::IterVar, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocatortvm::IterVar > const&, std::unordered_map<tvm::IterVar, HalideIR::Expr, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, HalideIR::Expr> > >, bool)+0x1877) [0x7f83018b9229]
[bt] (3) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::ComputeLoopNest::make(tvm::ComputeOpNode const
, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x18e) [0x7f8301886302]
[bt] (4) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::MakeComputeStmt(tvm::ComputeOpNode const*, tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool)+0x70) [0x7f8301885037]
[bt] (5) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::ComputeOpNode::BuildProvide(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, bool) const+0x1dd) [0x7f830188611b]
[bt] (6) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::schedule::MakePipeline(tvm::Stage const&, std::unordered_map<tvm::IterVar, tvm::Range, std::hashtvm::IterVar, std::equal_totvm::IterVar, std::allocator<std::pair<tvm::IterVar const, tvm::Range> > > const&, HalideIR::Internal::Stmt, bool)+0x8a) [0x7f83018caf52]
[bt] (7) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::schedule::InjectAttach::Mutate(HalideIR::Internal::Stmt)+0x2ca) [0x7f83018ccfd8]
[bt] (8) /home/xiaoquan.li/tvm/build/libtvm.so(tvm::ir::IRMutator::Mutate_(HalideIR::Internal::AttrStmt const*, HalideIR::Internal::Stmt const&)+0xc9) [0x7f8301735097]
[bt] (9) /home/xiaoquan.li/tvm/build/libtvm.so(+0x14d9c47) [0x7f8301737c47]

import tvm

# Problem sizes: 10x10 input, 3x3 filter (valid convolution), 2x2 max pool.
n = 10
Input = tvm.placeholder((n, n), name='Input')
Filter = tvm.placeholder((3, 3), name='Filter')
di = tvm.reduce_axis((0, 3), name='di')
dj = tvm.reduce_axis((0, 3), name='dj')

# A valid (no-padding) 3x3 convolution shrinks each spatial dim by 2.
conv2d_output_shape = (n - 2, n - 2)

build_opencl = False

# 2-D convolution: each output element sums a 3x3 window of Input
# weighted by Filter.  (Uses conv2d_output_shape instead of re-hard-coding
# the same tuple, which previously left that variable unused.)
Output = tvm.compute(
    conv2d_output_shape,
    lambda i, j: tvm.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
    name='Output')

# 2x2 max pooling (stride 1) over the convolution result.  The row index i
# is offset by the height axis and the column index j by the width axis;
# both ranges are (0, 2), so the computed values are unchanged -- only the
# previously swapped naming is corrected.
pool_height = tvm.reduce_axis((0, 2), name="height")
pool_width = tvm.reduce_axis((0, 2), name="width")

pool_output = tvm.compute(
    (n - 3, n - 3),
    lambda i, j: tvm.max(Output[i + pool_height, j + pool_width],
                         axis=[pool_height, pool_width]),
    name='Pool')

s = tvm.create_schedule(pool_output.op)

# Show the default (unfused) schedule first.
# NOTE(review): the argument list should name the scheduled result
# pool_output, not the intermediate Output tensor -- once Output is
# computed inline below it is an internal buffer, not a function argument.
print(tvm.lower(s, [Input, Filter, pool_output], simple_mode=True))

# Fuse the convolution into the pooling loop nest at the inner axis.
s[Output].compute_at(s[pool_output], pool_output.op.axis[1])

# NOTE(review): binding Output's original axes after compute_at is what
# triggers "Check failed: is_zero(dom->min)" in op_util.cc -- after the
# attachment Output's iteration domain apparently no longer starts at
# zero, which bind() seems to require.  Kept as-is on purpose so the
# script still reproduces the reported error.
s[Output].bind(s[Output].op.axis[0], tvm.thread_axis("blockIdx.x"))
s[Output].bind(s[Output].op.axis[1], tvm.thread_axis("threadIdx.x"))

s[pool_output].bind(s[pool_output].op.axis[0], tvm.thread_axis("blockIdx.x"))
s[pool_output].bind(s[pool_output].op.axis[1], tvm.thread_axis("threadIdx.y"))

print(tvm.lower(s, [Input, Filter, pool_output], simple_mode=True))

if build_opencl:
    fopencl = tvm.build(s, [Input, Filter, pool_output], 'opencl')
    print(fopencl.imported_modules[0].get_source())