The question is: is the test case below expected to work, or is this a known limitation?
Take this example:
import tvm
xa = tvm.var('xa', 'int32')
ya = tvm.var('ya', 'int32')
A = tvm.placeholder((xa,ya), dtype='int32', name='A')
xb = tvm.var('xb', 'int32')
yb = tvm.var('yb', 'int32')
B = tvm.placeholder((xb,yb), dtype='int32', name='B')
xc = tvm.var('xc', 'int32')
yc = tvm.var('yc', 'int32')
C = tvm.placeholder((xc,yc), dtype='int32', name='C')
xo = A.shape[0]
yo = (A.shape[1]+B.shape[1]) // C.shape[1]
O = tvm.compute((xo,yo), lambda i,j: 0, name='O')
s = tvm.create_schedule(O.op)
f = tvm.build(s, [O, A, B, C], target = 'llvm')
The idea is to pass the tensors alone (without passing the `x_`/`y_` shape variables explicitly), and have the output shape information extracted from the input tensors instead. However, I’m hitting an assertion failure:
[16:58:10] /w/src/dmlc/tvm/src/codegen/llvm/codegen_llvm.cc:670: Check failed: it != var_map_.end(): cannot find variable yc
Stack trace:
[bt] (0) /w/src/dmlc/tvm/build.x86/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x33) [0x7fc458aa86b3]
[bt] (1) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::codegen::CodeGenLLVM::GetVarValue(tvm::Variable const*) const+0x1c3) [0x7fc459360313]
[bt] (2) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*) const+0xf4) [0x7fc45933f704]
[bt] (3) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::codegen::CodeGenLLVM::VisitExpr_(tvm::ir::GE const*)+0x1f) [0x7fc45936764f]
[bt] (4) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*) const+0xf4) [0x7fc45933f704]
[bt] (5) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::codegen::CodeGenLLVM::VisitExpr_(tvm::ir::And const*)+0x1f) [0x7fc459367b1f]
[bt] (6) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*) const+0xf4) [0x7fc45933f704]
[bt] (7) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::codegen::CodeGenLLVM::VisitExpr_(tvm::ir::Or const*)+0x1f) [0x7fc459367b7f]
[bt] (8) /w/src/dmlc/tvm/build.x86/libtvm.so(tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::ir::ExprFunctor<llvm::Value* (tvm::Expr const&)>*) const+0xf4) [0x7fc45933f704]
Aborted
When I dump the LoweredFunc right before codegen, it is:
default_function(handle args, handle arg_type_ids, int32 num_args) {
assert((num_args == 4), "default_function: num_args should be 4")
let arg0 = tvm_struct_get(args, 0, 12)
let arg0.code = arg_type_ids[0]
let arg1 = tvm_struct_get(args, 1, 12)
let arg1.code = arg_type_ids[1]
let arg2 = tvm_struct_get(args, 2, 12)
let arg2.code = arg_type_ids[2]
let arg3 = tvm_struct_get(args, 3, 12)
let arg3.code = arg_type_ids[3]
let O = tvm_struct_get(arg0, 0, 1)
// attr [O] storage_alignment = 128
let arg0.shape = tvm_struct_get(arg0, 0, 2)
let xa = int32(arg0.shape[0])
let arg0.strides = tvm_struct_get(arg0, 0, 3)
let stride = tvm_if_then_else((select((((yc >= 0) && (((ya + yb) % yc) >= 0)) || ((yc < 0) && (((ya + yb) % yc) <= 0))), ((ya + yb)/yc), (((ya + yb)/yc) - 1)) == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg0.strides), 1, int32(arg0.strides[1])))
let stride = tvm_if_then_else((xa == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg0.strides), select((((yc >= 0) && (((ya + yb) % yc) >= 0)) || ((yc < 0) && (((ya + yb) % yc) <= 0))), ((ya + yb)/yc), (((ya + yb)/yc) - 1)), int32(arg0.strides[0])))
let dev_type = tvm_struct_get(arg0, 0, 10)
let dev_id = tvm_struct_get(arg0, 0, 9)
let A = tvm_struct_get(arg1, 0, 1)
// attr [A] storage_alignment = 128
let arg1.shape = tvm_struct_get(arg1, 0, 2)
let ya = int32(arg1.shape[1])
let arg1.strides = tvm_struct_get(arg1, 0, 3)
let stride = tvm_if_then_else((ya == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg1.strides), 1, int32(arg1.strides[1])))
let stride = tvm_if_then_else((xa == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg1.strides), ya, int32(arg1.strides[0])))
let B = tvm_struct_get(arg2, 0, 1)
// attr [B] storage_alignment = 128
let arg2.shape = tvm_struct_get(arg2, 0, 2)
let xb = int32(arg2.shape[0])
let yb = int32(arg2.shape[1])
let arg2.strides = tvm_struct_get(arg2, 0, 3)
let stride = tvm_if_then_else((yb == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg2.strides), 1, int32(arg2.strides[1])))
let stride = tvm_if_then_else((xb == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg2.strides), yb, int32(arg2.strides[0])))
let C = tvm_struct_get(arg3, 0, 1)
// attr [C] storage_alignment = 128
let arg3.shape = tvm_struct_get(arg3, 0, 2)
let xc = int32(arg3.shape[0])
let yc = int32(arg3.shape[1])
let arg3.strides = tvm_struct_get(arg3, 0, 3)
let stride = tvm_if_then_else((yc == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg3.strides), 1, int32(arg3.strides[1])))
let stride = tvm_if_then_else((xc == 1), 0, tvm_if_then_else(tvm_handle_is_null(arg3.strides), yc, int32(arg3.strides[0])))
assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || (arg0.code == 4)), "default_function: Expect arg[0] to be pointer")
assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || (arg1.code == 4)), "default_function: Expect arg[1] to be pointer")
assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || (arg2.code == 4)), "default_function: Expect arg[2] to be pointer")
assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) || (arg3.code == 4)), "default_function: Expect arg[3] to be pointer")
assert((dev_type == 1), "device_type need to be 1")
assert((2 == tvm_struct_get(arg0, 0, 4)), "arg0.ndim is expected to equal 2")
assert((((tvm_struct_get(arg0, 0, 5) == (uint8)0) && (tvm_struct_get(arg0, 0, 6) == (uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), "arg0.dtype is expected to be int32")
assert((select((((yc >= 0) && (((ya + yb) % yc) >= 0)) || ((yc < 0) && (((ya + yb) % yc) <= 0))), ((ya + yb)/yc), (((ya + yb)/yc) - 1)) == int32(arg0.shape[1])), "Argument arg0.shape[1] has an unsatisfied constraint")
assert(((uint64)0 == tvm_struct_get(arg0, 0, 8)), "Argument arg0.byte_offset has an unsatisfied constraint")
assert((2 == tvm_struct_get(arg1, 0, 4)), "arg1.ndim is expected to equal 2")
assert((((tvm_struct_get(arg1, 0, 5) == (uint8)0) && (tvm_struct_get(arg1, 0, 6) == (uint8)32)) && (tvm_struct_get(arg1, 0, 7) == (uint16)1)), "arg1.dtype is expected to be int32")
assert((xa == int32(arg1.shape[0])), "Argument arg1.shape[0] has an unsatisfied constraint")
assert(((uint64)0 == tvm_struct_get(arg1, 0, 8)), "Argument arg1.byte_offset has an unsatisfied constraint")
assert((1 == tvm_struct_get(arg1, 0, 10)), "Argument arg1.device_type has an unsatisfied constraint")
assert((dev_id == tvm_struct_get(arg1, 0, 9)), "Argument arg1.device_id has an unsatisfied constraint")
assert((2 == tvm_struct_get(arg2, 0, 4)), "arg2.ndim is expected to equal 2")
assert((((tvm_struct_get(arg2, 0, 5) == (uint8)0) && (tvm_struct_get(arg2, 0, 6) == (uint8)32)) && (tvm_struct_get(arg2, 0, 7) == (uint16)1)), "arg2.dtype is expected to be int32")
assert(((uint64)0 == tvm_struct_get(arg2, 0, 8)), "Argument arg2.byte_offset has an unsatisfied constraint")
assert((1 == tvm_struct_get(arg2, 0, 10)), "Argument arg2.device_type has an unsatisfied constraint")
assert((dev_id == tvm_struct_get(arg2, 0, 9)), "Argument arg2.device_id has an unsatisfied constraint")
assert((2 == tvm_struct_get(arg3, 0, 4)), "arg3.ndim is expected to equal 2")
assert((((tvm_struct_get(arg3, 0, 5) == (uint8)0) && (tvm_struct_get(arg3, 0, 6) == (uint8)32)) && (tvm_struct_get(arg3, 0, 7) == (uint16)1)), "arg3.dtype is expected to be int32")
assert(((uint64)0 == tvm_struct_get(arg3, 0, 8)), "Argument arg3.byte_offset has an unsatisfied constraint")
assert((1 == tvm_struct_get(arg3, 0, 10)), "Argument arg3.device_type has an unsatisfied constraint")
assert((dev_id == tvm_struct_get(arg3, 0, 9)), "Argument arg3.device_id has an unsatisfied constraint")
0
// attr [0] compute_scope = "default_function_compute_"
produce O {
for (i, 0, xa) {
for (j, 0, select((((yc >= 0) && (((ya + yb) % yc) >= 0)) || ((yc < 0) && (((ya + yb) % yc) <= 0))), ((ya + yb)/yc), (((ya + yb)/yc) - 1))) {
O[((i*stride) + (j*stride))] = 0
}
}
}
}
You can see that `yc` is indeed accessed there (in the `stride` computation for `arg0`) before it is assigned from `arg3.shape[1]`, although it may just be a different variable with the same name.