[Relay][concatenate] "Downcast from relay.RefType to relay.TensorType failed."

I am trying to register an API that implements the gradient of relay.concatenate. Below is the code I implemented, but I think some data is missing.

The following is the C++ code I implemented for the concatenate gradient:

namespace tvm {
namespace relay {

//concatenate split
bool ConcatenateGradRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  // `types` contains: [data, result]

  LOG(INFO) << "############### Rel Start: #################3" <<std::endl;
  LOG(INFO) << "type:  " << types <<std::endl;
  LOG(INFO) << "attrs:  " << attrs <<std::endl;

  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
  const auto param = attrs.as<SplitAttrs>();
  CHECK(param != nullptr);
  auto axis = param->axis;
  if (axis < 0) {
    axis += data->shape.size();
  }
  CHECK_LT(axis, data->shape.size())
    << "axis should be within the input dimension range.";
  CHECK_GE(axis, 0)
    << "axis should be within the input dimension range.";

  auto indices = param->indices_or_sections.as<ArrayNode>()->data;
  auto begin = IndexExpr(make_zero(Int(32)));
  std::vector<Type> fields;
  for (unsigned int i = 0; i < indices.size(); ++i) {
    CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
        << "indices_or_sections need to be a sorted ascending list";
    std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
    oshape[axis] = IndexExpr(indices[i]) - begin;
    begin = IndexExpr(indices[i]);
    auto vec_type = TensorTypeNode::make(oshape, data->dtype);
    fields.push_back(vec_type);
  }
  CHECK(reporter->Assert(begin < data->shape[axis]))
      << "The sum of sections must match the input.shape[axis]";
  std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
  oshape[axis] = data->shape[axis] - begin;
  auto vec_type = TensorTypeNode::make(oshape, data->dtype);
  fields.push_back(vec_type);

  LOG(INFO) << "Indices fields: " << Array<Type>(fields) <<std::endl;
  LOG(INFO) << "############ Grad Rel end ########## " <<std::endl;

  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}

/*!
 * \brief Compute function for "concatenate_grad": splits the single input
 *        tensor with topi::split using the indices and axis in SplitAttrs.
 *
 * \param attrs    Must hold SplitAttrs whose `indices_or_sections` is an
 *                 array of integers.
 * \param inputs   Exactly one tensor: the gradient to be split.
 * \param out_type Output type (only logged here).
 * \param target   Compilation target (only logged here).
 * \return The tensors produced by topi::split.
 */
Array<Tensor> ConcatenateGradCompute(const Attrs& attrs,
                                     const Array<Tensor>& inputs,
                                     const Type& out_type,
                                     const Target& target) {

  LOG(INFO) << "############ Grad Compute Start: ##########" <<std::endl;
  LOG(INFO) << "attrs: " << attrs <<std::endl;
  LOG(INFO) << "inputs: " << inputs <<std::endl;
  LOG(INFO) << "out type: " << out_type <<std::endl;
  LOG(INFO) << "target: " << target <<std::endl;

  const auto* param = attrs.as<SplitAttrs>();
  // The original statement lacked its terminating semicolon and did not compile.
  CHECK(param != nullptr);
  // inputs[0] is read below, so make the arity requirement explicit.
  CHECK_EQ(inputs.size(), 1) << "concatenate_grad expects exactly one input";
  auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
  LOG(INFO) << "indices:" << indices <<std::endl;
  LOG(INFO) << "############ Split Compute End ##########" <<std::endl;
  return Array<Tensor>{ topi::split(inputs[0], indices, param->axis) };

}

/*!
 * \brief Construct a call to the "concatenate_grad" operator.
 *
 * \param data    Expression passed as the only argument of the call.
 * \param indices Stored in the call's SplitAttrs as `indices_or_sections`.
 * \param axis    Stored in the call's SplitAttrs as `axis`.
 * \return The CallNode expression invoking "concatenate_grad".
 */
Expr MakeConcatenateGrad(Expr data,
                         NodeRef indices,
                         int axis) {
  LOG(INFO) << "############ Make Before ##############" << std::endl;
  LOG(INFO) << "Expr data:  " << data << std::endl;
  LOG(INFO) << "Node ref :  " << indices << std::endl;

  // Fill in the attribute node only after `indices` has been logged above,
  // because it is moved from here.
  auto split_attrs = make_node<SplitAttrs>();
  split_attrs->indices_or_sections = std::move(indices);
  split_attrs->axis = axis;
  const Attrs call_attrs(split_attrs);
  LOG(INFO) << "attrs:  " << call_attrs << std::endl;

  // Resolved on the first call and reused afterwards.
  static const Op& grad_op = Op::Get("concatenate_grad");
  LOG(INFO) << "############ Make End ##############" << std::endl;
  return CallNode::make(grad_op, {data}, call_attrs, {});
}


// Expose MakeConcatenateGrad to the Python frontend under
// relay.op._make.concatenate_grad.
TVM_REGISTER_API("relay.op._make.concatenate_grad")
.set_body_typed(MakeConcatenateGrad);

// MakeConcatenateGrad builds SplitAttrs and both ConcatenateGradRel and
// ConcatenateGradCompute cast the attrs to SplitAttrs, so the registered
// attrs type key must name SplitAttrs rather than ConcatenateAttrs.
RELAY_REGISTER_OP("concatenate_grad")
.describe(R"code(concatenate grad.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SplitAttrs")
.set_num_inputs(1)
.add_argument("ograd", "Tensor", "The gradient of output.")
.set_support_level(1)
.add_type_rel("ConcatenateGradRel", ConcatenateGradRel)
.set_attr<FTVMCompute>("FTVMCompute", ConcatenateGradCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

}  // namespace relay
}  // namespace tvm

I use relay.concatenate to join two matrices: one has shape 2x3 and the other has shape 3x3. The concatenation axis is 0, so the result is a 5x3 matrix. Below is my test case code.

def test_concat():
    """Build concatenate over several inputs, differentiate it and check the result."""
    def verify_concat(dshapes, axis):
        """Run forward and backward concatenate for inputs of shapes `dshapes`.

        :param dshapes: list of input shapes, one relay variable is created per shape
        :param axis: axis along which the inputs are concatenated
        """
        y = []
        for shape in dshapes:
            y.append(relay.var("input", relay.TensorType(shape, "float32")))

        print("y", y)
        x = relay.Tuple(y)
        print("x:\n", x)
        z = relay.concatenate(x, axis=axis)
        print("z:\n", z)

        func = relay.Function(y, z)
        x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
        print("xdata:\n", x_data)

        fwd_func = run_infer_type(func)
        print("fwd fun:\n", fwd_func)
        bwd_func = run_infer_type(gradient(fwd_func))
        print("bwd fun:\n", bwd_func)

        intrp = relay.create_executor("graph", ctx=tvm.context('llvm', 0), target="llvm")

        # The backward function returns (result, (grad_0, ..., grad_n-1)): one
        # gradient per input, so do not unpack a fixed single-element tuple.
        op_res, op_grads = intrp.evaluate(bwd_func)(*x_data)
        op_grads = list(op_grads)
        print("op grad\n", op_grads)

        assert len(op_grads) == len(dshapes)
        np.testing.assert_allclose(op_res.asnumpy(), np.concatenate(x_data, axis=axis))
        # The output gradient is seeded with ones, so every input gradient is
        # a tensor of ones with that input's shape.
        for data, op_grad in zip(x_data, op_grads):
            np.testing.assert_allclose(op_grad.asnumpy(), np.ones_like(data))
    verify_concat([(2, 3), (3, 3)], 0)

I register the gradient op in Python:

@register_gradient("concatenate")
def concatenate_grad(orig, grad):
    """
    Return concatenate gradient

    The gradient of concatenate is the output gradient split back into one
    piece per concatenated input.

    :param orig: the original concatenate call; ``orig.attrs.axis`` is the
        concatenation axis and ``orig.type_args[0]`` is the tuple type of its
        single argument
    :param grad: initial gradient computed for execution result of concatenate
    :return: one-element list holding the gradient for the tuple argument
    """
    axis = orig.attrs.axis
    data_type = orig.type_args[0]

    # The running sum of the extents along `axis` of every field except the
    # last gives the positions at which the output gradient must be split.
    indices = []
    split_indices = 0
    for i in range(len(data_type.fields) - 1):
        split_indices += data_type.fields[i].shape[axis]
        indices.append(split_indices)

    # Pass the `grad` parameter (the original referenced an undefined `ograd`)
    # and the computed split indices (the original passed `data_type`).
    # Gradient functions return a list with one entry per call argument.
    return [_make.concatenate_grad(grad, indices, axis)]

Below is the error log

TVMError: 
Error(s) have occurred. The program has been annotated with them:

In `main`: 
v0.0.3
fn (%input: Tensor[(2, 3), float32], %input1: Tensor[(3, 3), float32]) -> (Tensor[(5, 3), float32], (Tensor[(2, 3), float32], Tensor[(3, 3), float32])) {
  %0 = fn () -> () {
    ()
  };
  let %x = ref(%0);
  %1 = zeros_like(%input);
  %2 = ref(%1);
  let %x1 = (%input, %2);
  %3 = zeros_like(%input1);
  %4 = ref(%3);
  let %x2 = (%input1, %4);
  %12 = fn (%input2: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %input3: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))) -> (Tensor[(5, 3), float32], ref(Tensor[(5, 3), float32])) {
    let %x4 = (%input2, %input3);
    %5 = %x4.0;
    let %x5 = concatenate(%5) an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; ;
    %6 = zeros_like(%x5);
    let %x6 = ref(%6);
    let %x7 = %x^;
    %11 = fn () -> () {
      let %x9 = %x6^;
      %7 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
      let %x10 = %7^;
      %8 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
      %9 = concatenate_grad(%x9, %x4);
      %10 = add(%x10, %9);
      let %x11 = (%8 := %10);
      %x7()
    };
    let %x8 = (%x := %11);
    (%x5, %x6)
  };
  let %x3 = %12(%x1, %x2);
  %13 = %x3.1;
  %14 = %x3.0;
  %15 = ones_like(%14);
  let %x12 = (%13 := %15);
  %16 = %x^;
  let %x13 = %16();
  %17 = %x3.0;
  %18 = %x1.1;
  %19 = %18^;
  %20 = %x2.1;
  %21 = %20^;
  %22 = (%19, %21);
  (%17, %22)
}

I pulled the newest code, and I can get the data now.
But it still has the same problem:
"Downcast from relay.RefType to relay.TensorType failed."

Perhaps see this: [RELAY]Downcast from relay.IncompleteType to relay.TensorType failed

There are some differences. I am trying to implement the gradient of relay.concatenate, but it seems that TVM doesn't know how to downcast tuple-typed data. The error log shows: "Downcast from relay.RefType to relay.TensorType failed."

This is the IR of the relay.concatenate forward function:

fn (%input: Tensor[(2, 3), float32], %input1: Tensor[(3, 3), float32]) -> Tensor[(5, 3), float32] {
  %0 = (%input, %input1);
  concatenate(%0) /* ty=Tensor[(5, 3), float32] */
}

After calling gradient (mode = 'higher_order'), this is the backward function:

TVMError: 
Error(s) have occurred. The program has been annotated with them:

In `main`: 
v0.0.3
fn (%input: Tensor[(2, 3), float32], %input1: Tensor[(3, 3), float32]) -> (Tensor[(5, 3), float32], (Tensor[(2, 3), float32], Tensor[(3, 3), float32])) {
  %0 = fn () -> () {
    ()
  };
  let %x = ref(%0);
  %1 = zeros_like(%input);
  %2 = ref(%1);
  let %x1 = (%input, %2);
  %3 = zeros_like(%input1);
  %4 = ref(%3);
  let %x2 = (%input1, %4);
  %12 = fn (%input2: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %input3: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))) -> (Tensor[(5, 3), float32], ref(Tensor[(5, 3), float32])) {
    let %x4 = (%input2, %input3);
    %5 = %x4.0;
    let %x5 = concatenate(%5) an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; ;
    %6 = zeros_like(%x5);
    let %x6 = ref(%6);
    let %x7 = %x^;
    %11 = fn () -> () {
      let %x9 = %x6^;
      %7 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
      let %x10 = %7^;
      %8 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
      %9 = concatenate_grad(%x9, meta[relay.attrs.SplitAttrs][0]);
      %10 = add(%x10, %9);
      let %x11 = (%8 := %10);
      %x7()
    };
    let %x8 = (%x := %11);
    (%x5, %x6)
  };
  let %x3 = %12(%x1, %x2);
  %13 = %x3.1;
  %14 = %x3.0;
  %15 = ones_like(%14);
  let %x12 = (%13 := %15);
  %16 = %x^;
  let %x13 = %16();
  %17 = %x3.0;
  %18 = %x1.1;
  %19 = %18^;
  %20 = %x2.1;
  %21 = %20^;
  %22 = (%19, %21);
  (%17, %22)
}
// meta data omitted. you can use show_meta_data=True to include meta data
Process finished with exit code 1

We can see from the log above that the backward pass changes the data type of the input data:

  %12 = fn (%input2: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %input3: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))) -> (Tensor[(5, 3), float32], ref(Tensor[(5, 3), float32])) {
    let %x4 = (%input2, %input3);
    %5 = %x4.0;
    let %x5 = concatenate(%5) an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [14:46:32] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; 

The forward data type is a Tensor type, '(%input: Tensor[(2, 3), float32], %input1: Tensor[(3, 3), float32])'. Why does the backward pass change it to '%input2: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %input3: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))'?

@MarisaKirisame please give me some suggestions

@Ruinhuang I fixed it in add_grad, will upstream rn.

@Ruinhuang see https://github.com/dmlc/tvm/pull/3729.

Awesome!!!
Thanks very much!!!:+1: