I try to register a api to implement the gradient of relay.concatenate, Below is the code I implemented. but i think there is missing some data.
The following code is the c++ code I implemented for concatenate grad
namespace tvm {
namespace relay {
//concatenate split
bool ConcatenateGradRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
LOG(INFO) << "############### Rel Start: #################3" <<std::endl;
LOG(INFO) << "type: " << types <<std::endl;
LOG(INFO) << "attrs: " << attrs <<std::endl;
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);
auto axis = param->axis;
if (axis < 0) {
axis += data->shape.size();
}
CHECK_LT(axis, data->shape.size())
<< "axis should be within the input dimension range.";
CHECK_GE(axis, 0)
<< "axis should be within the input dimension range.";
auto indices = param->indices_or_sections.as<ArrayNode>()->data;
auto begin = IndexExpr(make_zero(Int(32)));
std::vector<Type> fields;
for (unsigned int i = 0; i < indices.size(); ++i) {
CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
<< "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = IndexExpr(indices[i]) - begin;
begin = IndexExpr(indices[i]);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
LOG(INFO) << "Indices fields: " << Array<Type>(fields) <<std::endl;
LOG(INFO) << "############ Grad Rel end ########## " <<std::endl;
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
return true;
}
Array<Tensor> ConcatenateGradCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
LOG(INFO) << "############ Grad Compute Start: ##########" <<std::endl;
LOG(INFO) << "attrs: " << attrs <<std::endl;
LOG(INFO) << "inputs: " << inputs <<std::endl;
LOG(INFO) << "out type: " << out_type <<std::endl;
LOG(INFO) << "target: " << target <<std::endl;
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr)
auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
LOG(INFO) << "indices:" << indices <<std::endl;
LOG(INFO) << "############ Split Compute End ##########" <<std::endl;
return Array<Tensor>{ topi::split(inputs[0], indices, param->axis) };
}
Expr MakeConcatenateGrad(Expr data,
NodeRef indices,
int axis) {
LOG(INFO) << "############ Make Before ##############" <<std::endl;
LOG(INFO) << "Expr data: " << data <<std::endl;
LOG(INFO) << "Node ref : " << indices <<std::endl;
auto attrs = make_node<SplitAttrs>();
attrs->axis = axis;
attrs->indices_or_sections = std::move(indices);
LOG(INFO) << "attrs: " << Attrs(attrs) <<std::endl;
static const Op& op = Op::Get("concatenate_grad");
LOG(INFO) << "############ Make End ##############" <<std::endl;
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.concatenate_grad")
.set_body_typed(MakeConcatenateGrad);
RELAY_REGISTER_OP("concatenate_grad")
.describe(R"code(concatenate grad.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ConcatenateAttrs")
.set_num_inputs(1)
.add_argument("ograd", "Tensor", "The gradient of output.")
.set_support_level(1)
.add_type_rel("ConcatenateGradRel", ConcatenateGradRel)
.set_attr<FTVMCompute>("FTVMCompute", ConcatenateGradCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
I use relay.concatenate connected 2 matrices. one shape is 2x3 and the other is 3x3, the concatenate axis is 0, so the result is a 5x3 matrices. At the below there is my test case code.
def test_concat():
def verify_concat(dshapes, axis):
y = []
for shape in dshapes:
y.append(relay.var("input", relay.TensorType(shape, "float32")))
print("y", y)
x = relay.Tuple(y)
print("x:\n", x)
z = relay.concatenate(x, axis=axis)
print("z:\n", z)
func = relay.Function(y, z)
x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
print("xdata:\n", x_data)
fwd_func = run_infer_type(func)
print("fwd fun:\n", fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))
print("fwd fun:\n", bwd_func)
intrp = relay.create_executor("graph", ctx=tvm.context('llvm', 0), target="llvm")
op_res, (op_grad,) = intrp.evaluate(bwd_func)(*x_data)
print("op grad\n", op_grad)
verify_concat([(2, 3), (3, 3)], 0)
i register the gradient op in python:
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
"""
Return concatenate gradient
:param orig: attrs(axis)
:param grad: initial gradient computed for execution result of concatenate
:return:
"""
from tvm.relay.expr import Tuple, Constant
from tvm import relay
import tvm
axis = orig.attrs.axis
data = orig.args[0]
data_type = orig.type_args[0]
indices = []
split_indices = 0
for i in range(len(data_type.fields) - 1):
split_indices += data_type.fields[i].shape[axis]
indices.append(split_indices)
return _make.concatenate_grad(ograd, data_type, axis)
Below is the error log
TVMError:
Error(s) have occurred. The program has been annotated with them:
In `main`:
v0.0.3
fn (%input: Tensor[(2, 3), float32], %input1: Tensor[(3, 3), float32]) -> (Tensor[(5, 3), float32], (Tensor[(2, 3), float32], Tensor[(3, 3), float32])) {
%0 = fn () -> () {
()
};
let %x = ref(%0);
%1 = zeros_like(%input);
%2 = ref(%1);
let %x1 = (%input, %2);
%3 = zeros_like(%input1);
%4 = ref(%3);
let %x2 = (%input1, %4);
%12 = fn (%input2: (Tensor[(2, 3), float32], ref(Tensor[(2, 3), float32])), %input3: (Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))) -> (Tensor[(5, 3), float32], ref(Tensor[(5, 3), float32])) {
let %x4 = (%input2, %input3);
%5 = %x4.0;
let %x5 = concatenate(%5) an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; an internal invariant was violated while typechecking your program [09:37:37] /home/rui.huang/tvm0805/include/tvm/node/node.h:285: Check failed: ref->template is_type<typename SubRef::ContainerType>() || ref->template derived_from<typename SubRef::ContainerType>(): Downcast from relay.RefType to relay.TensorType failed.
; ;
%6 = zeros_like(%x5);
let %x6 = ref(%6);
let %x7 = %x^;
%11 = fn () -> () {
let %x9 = %x6^;
%7 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
let %x10 = %7^;
%8 = %x4.1 unable to unify: `(Tensor[(3, 3), float32], ref(Tensor[(3, 3), float32]))` and `ref(meta[relay.IncompleteType][0])
// meta data omitted. you can use show_meta_data=True to include meta data`; ;
%9 = concatenate_grad(%x9, %x4);
%10 = add(%x10, %9);
let %x11 = (%8 := %10);
%x7()
};
let %x8 = (%x := %11);
(%x5, %x6)
};
let %x3 = %12(%x1, %x2);
%13 = %x3.1;
%14 = %x3.0;
%15 = ones_like(%14);
let %x12 = (%13 := %15);
%16 = %x^;
let %x13 = %16();
%17 = %x3.0;
%18 = %x1.1;
%19 = %18^;
%20 = %x2.1;
%21 = %20^;
%22 = (%19, %21);
(%17, %22)
}