Relay infer_type bug, or does the convolution op unreasonably enforce matching input data types?

I have been working on the quantization algorithm recently and ran into the following bug. Can anyone help me with it? Thanks!

Description: for convolution, I sometimes want to quantize the input from float32 to uint8 and the weight from float32 to int8. Unfortunately, the infer_type error below appears. I am not familiar with the ir_pass.infer_type source code, so how can I solve this? Thanks.

Log: tvm/src/relay/pass/type_solver.cc:99: Check failed: resolved.defined() Unable to unify parent types: TensorType([64, 3, 7, 7], uint8) and TensorType([64, 3, 7, 7], int8)

A guess is that you may need to change the backing code that defines the ops to match the data types that you want to use (e.g., https://github.com/dmlc/tvm/blob/5a27632e274fff57087ed0b6eb2856b6e5946cfb/src/relay/pass/quantize.cc#L281).

The bug is probably not in infer_type, as this is just enforcing that data types match when they should.
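To illustrate what "enforcing that data types match" means here, a minimal Python sketch follows. It is a toy model, not TVM code: `TensorType` and `unify` are hypothetical names, and the only things taken from the thread are the weight shape and the two dtypes in the log.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorType:
    """Toy stand-in for a tensor type: a shape plus a dtype string."""
    shape: Tuple[int, ...]
    dtype: str


def unify(lhs: TensorType, rhs: TensorType) -> TensorType:
    """Return the common type, or raise if the two types differ at all."""
    if lhs != rhs:
        raise TypeError(f"Unable to unify parent types: {lhs} and {rhs}")
    return lhs


# The type the op expects for the weight (dtype copied from the data) ...
weight_expected = TensorType((64, 3, 7, 7), "uint8")
# ... versus the type the quantized weight actually has.
weight_actual = TensorType((64, 3, 7, 7), "int8")

try:
    unify(weight_expected, weight_actual)
    unify_error = None
except TypeError as err:
    unify_error = str(err)

print(unify_error)
```

In this model the checker is doing its job: the mismatch comes from whoever decided that the weight should be uint8, not from the comparison itself.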

Thank you for your reply.
As you said, I can change the backing code to change the data types. I can also achieve this by setting cfg->dtype_weight when the quantize config is initialized.
The reason for this requirement is that we want to use uint8 instead of int8 when quantizing a convolution that follows a ReLU layer, while the convolution weights use int8.
The bug is probably not in infer_type, but why does the convolution operation sometimes enforce matching data types?
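For readers wondering why the mixed uint8/int8 setup is attractive: activations after a ReLU are never negative, so an unsigned type can spend all 256 levels on the useful range, while weights are signed. The sketch below is plain Python arithmetic under assumed values; the `quantize` helper, the sample numbers, and the max-based scales are illustrative and are not how TVM's quantize pass chooses its parameters.

```python
def quantize(values, scale, lo, hi):
    """Divide by the scale, round to the nearest integer, clamp to [lo, hi]."""
    return [max(lo, min(hi, round(v / scale))) for v in values]


activations = [0.0, 0.4, 1.1, 2.0]   # post-ReLU values, never negative
weights = [-1.0, -0.3, 0.6, 1.0]     # signed values

act_scale = max(activations) / 255             # uint8 range [0, 255]
w_scale = max(abs(w) for w in weights) / 127   # symmetric int8 range [-127, 127]

q_act = quantize(activations, act_scale, 0, 255)
q_w = quantize(weights, w_scale, -127, 127)

# Integer dot product; a real kernel would accumulate this in int32.
acc = sum(a * w for a, w in zip(q_act, q_w))

approx = acc * act_scale * w_scale                         # dequantized result
exact = sum(a * w for a, w in zip(activations, weights))   # float reference

print(q_act, q_w, acc, approx, exact)
```

With int8 activations only the levels 0 to 127 would be reachable for the same non-negative data, which halves the resolution.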

OK, so the problem may come from the existing type relations defined for ops like Conv2D, which may enforce that the kernel dtype is the same as the data dtype. You could try changing https://github.com/dmlc/tvm/blob/5a27632e274fff57087ed0b6eb2856b6e5946cfb/src/relay/op/nn/convolution.cc#L87 to see whether that is what triggers the error, but I am not sure how to fix this.
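The kind of type relation described above can be modelled in a few lines of Python. This is a toy sketch and not the Relay implementation: `conv2d_weight_type` and its `follow_weight_dtype` flag are hypothetical, the input shape is an assumed example, and the NCHW/OIHW layout is assumed as well.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorType:
    """Toy stand-in for a tensor type: a shape plus a dtype string."""
    shape: Tuple[int, ...]
    dtype: str


def conv2d_weight_type(data, weight, channels, kernel_size,
                       follow_weight_dtype=False):
    """Infer the OIHW weight type for an NCHW convolution (toy model).

    By default the weight dtype is copied from the data, which forces
    both inputs to share a dtype. With follow_weight_dtype=True the
    weight keeps its own dtype.
    """
    wshape = (channels, data.shape[1], *kernel_size)
    dtype = weight.dtype if follow_weight_dtype else data.dtype
    return TensorType(wshape, dtype)


data = TensorType((1, 3, 224, 224), "uint8")   # assumed input shape
weight = TensorType((64, 3, 7, 7), "int8")     # shape from the error log

strict = conv2d_weight_type(data, weight, 64, (7, 7))
relaxed = conv2d_weight_type(data, weight, 64, (7, 7), follow_weight_dtype=True)

print(strict)    # dtype follows the data, so it disagrees with `weight`
print(relaxed)   # dtype follows the weight, so it agrees with `weight`
```

In the strict variant the inferred type and the actual weight type differ only in dtype, which is the situation reported in the log.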

Following your suggestion, I changed the code from "reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));" to "reporter->Assign(types[1], TensorTypeNode::make(wshape, weight->dtype));". infer_type now passes, but a new error appears, shown below. I will create a new question. Thank you very much!

Log:
Traceback (most recent call last):
File "/home/ai/solomon/workspace/code/tvm/python/tvm/relay/backend/compile_engine.py", line 76, in lower
return _backend._CompileEngineLower(self, key)
File "/home/ai/solomon/workspace/code/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in call
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
File "/home/ai/solomon/workspace/code/tvm/python/tvm/_ffi/base.py", line 71, in check_call
raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: TVMCall CFunc Error:
Traceback (most recent call last):
File "/home/ai/solomon/workspace/code/tvm/python/tvm/_ffi/_ctypes/function.py", line 55, in cfun
rv = local_pyfunc(*pyargs)
File "/home/ai/solomon/workspace/code/tvm/python/tvm/relay/op/nn/_nn.py", line 337, in compute_contrib_conv2d_NCHWc
data_layout, out_layout, out_dtype)
File "", line 2, in conv2d_NCHWc
File "/home/ai/solomon/workspace/code/tvm/python/tvm/target.py", line 356, in dispatch_func
return dispatch_dict[k](*args, **kwargs)
File "", line 2, in config_dispatcher
File "/home/ai/solomon/workspace/code/tvm/python/tvm/autotvm/task/dispatcher.py", line 199, in dispatch_func
return dispatch_dict['direct'](cfg, *args, **kwargs)
File "/home/ai/solomon/workspace/code/tvm/python/tvm/autotvm/task/topi_integration.py", line 267, in template_call
node = f(cfg, *args, **kwargs)
File "/home/ai/solomon/workspace/code/tvm/topi/python/topi/x86/conv2d.py", line 377, in _declaration_conv_NCHWc
oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
ValueError: not enough values to unpack (expected 7, got 6)
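The final ValueError is ordinary Python tuple unpacking: the failing line expects a kernel shape with seven dimensions and received one with six. One possible reading, which is a guess and not confirmed in this thread, is that the x86 schedule takes a code path that expects a 7-D blocked kernel layout while the kernel was still packed in a 6-D layout. The sketch below only reproduces the unpacking failure; `unpack_kernel_shape` and both shapes are hypothetical.

```python
def unpack_kernel_shape(kernel_shape):
    """Unpack a kernel shape into seven fields, as the failing line does."""
    oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = kernel_shape
    return oc_chunk, kernel_height, kernel_width, oc_bn


six_d = (4, 1, 7, 7, 3, 16)        # hypothetical 6-D blocked kernel shape
seven_d = (4, 1, 7, 7, 1, 16, 4)   # hypothetical 7-D blocked kernel shape

try:
    unpack_kernel_shape(six_d)
    unpack_error = None
except ValueError as err:
    unpack_error = str(err)

print(unpack_error)
print(unpack_kernel_shape(seven_d))
```

If that reading is right, the thing to check would be which kernel layout the layout-transformation step produced for the int8 weight, and not the unpacking line itself.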

I also saw the same issue: the type-resolving pass does not unify "uint8" and "int8".
I think TVM needs a fix in the type unification part. Currently, it just checks AlphaEqual on the two input types. For your information, I unblocked myself by falling back to the rhs type (a hacky workaround):

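  Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
  // Hack: when the two types cannot be unified, fall back to the rhs type
  // instead of failing. Note that lhs and rhs are not merged on this path.
  if (!resolved.defined()) {
    return rhs->resolved_type;
  }
  // Original check, kept as is; it can no longer fail after the early return.
  CHECK(resolved.defined())
    << "Unable to unify parent types: "
    << lhs->resolved_type << " and " << rhs->resolved_type;
  TypeNode* top = solver_->GetTypeNode(resolved);
  solver_->MergeFromTo(lhs, top);
  solver_->MergeFromTo(rhs, top);
  return resolved;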
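To show why a fallback like this is hacky, here is a toy Python model of it. It is not the TVM type solver: types are plain `(shape, dtype)` tuples and `unify_with_fallback` is a hypothetical function that mimics only the "return rhs on failure" behaviour.

```python
def unify_with_fallback(lhs, rhs):
    """Return the common type, or silently pick rhs when the types differ."""
    if lhs == rhs:
        return lhs
    return rhs  # fallback: no error is raised, rhs simply wins


uint8_weight = ((64, 3, 7, 7), "uint8")
int8_weight = ((64, 3, 7, 7), "int8")
other_shape = ((32, 3, 3, 3), "int8")   # hypothetical, different shape

forward = unify_with_fallback(uint8_weight, int8_weight)
backward = unify_with_fallback(int8_weight, uint8_weight)
shape_mismatch = unify_with_fallback(uint8_weight, other_shape)

print(forward)         # rhs wins: the int8 type
print(backward)        # rhs wins again: now the uint8 type
print(shape_mismatch)  # a shape mismatch is accepted too
```

In this model the result depends on argument order, and a genuine mismatch such as a wrong shape passes as well, so a fallback of this kind is probably only useful for getting unblocked locally.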