[Frontend][Pytorch] Convert RCNN model from the Torchvision model zoo

Hi,

I’m trying to convert an RCNN model from the torchvision model zoo. Because of a known issue with JIT tracing for RCNN (https://discuss.pytorch.org/t/fasterrcnn-resnet50-jit-trace/62337), we can’t use JIT tracing for RCNN models. If I use JIT scripting instead, I get a bunch of missing ops:

NotImplementedError: The following operators are not implemented: ['aten::_set_item', 'torchvision::_new_empty_tensor_op', 'prim::dtype', 'aten::__isnot__', 'prim::max', 'torchvision::nms', 'prim::requires_grad', 'aten::index', 'aten::is_scripting', 'prim::DictConstruct', 'aten::uniform_', 'aten::tensor', 'aten::nll_loss2d', 'aten::smooth_l1_loss', 'prim::layout', 'aten::__and__', 'aten::empty', 'aten::__derive_index', 'aten::list', 'aten::__interpolate', 'aten::nonzero', 'aten::new_full', 'aten::index_put_', 'prim::CreateObject', 'aten::items', 'prim::RaiseException', 'prim::shape', 'aten::floordiv', 'aten::warn', 'aten::__or__', 'aten::__is__', 'prim::SetAttr', 'aten::item', 'aten::randperm', 'aten::__not__', 'prim::Uninitialized', 'aten::numel', 'aten::dim', 'aten::new_empty', 'aten::nll_loss', 'aten::unbind', 'aten::append', 'aten::broadcast_tensors', 'aten::is_pinned', 'aten::values', 'prim::unchecked_cast', 'aten::_cast_Float', 'aten::__range_length', 'aten::_cast_Byte', 'prim::min', 'aten::_get_tracing_state', 'aten::update', 'prim::TupleIndex', 'aten::format', 'aten::clear', 'aten::dict', 'aten::copy_', 'aten::scalar_tensor', 'aten::conv2d', 'aten::str', 'aten::__contains__', 'aten::binary_cross_entropy_with_logits', 'aten::remainder', 'aten::keys', 'aten::insert', 'aten::as_tensor', 'torchvision::roi_align', 'aten::l1_loss']

Some of them are just differences in op names, such as aten::conv2d.

A simple script to reproduce:

import torch
import torchvision

from tvm import relay

model_name = 'fasterrcnn_resnet50_fpn'
model = getattr(torchvision.models.detection, model_name)(pretrained=True)
model = model.eval()

input_shape = [1, 3, 512, 512]
scripted_model = torch.jit.script(model)

input_name = 'input0'
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

Any idea on this? @masahi @t-vi @alexwong

I remember having a script that demonstrates that tracing works for both Faster R-CNN and Mask R-CNN. The trick is just to wrap the model so that its dict output is converted to a tuple (tracing doesn’t support dict outputs).

import torch
import torchvision
import numpy as np
from tvm import relay


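def do_script(model):
    # Scripting compiles the Python source, so data-dependent control flow is kept.
    model_script = torch.jit.script(model)
    model_script.eval()
    return model_script


def do_trace(model, in_size=100):
    # Tracing records only the ops executed for this one example input.
    model_trace = torch.jit.trace(model, torch.rand(1, 3, in_size, in_size))
    model_trace.eval()
    return model_trace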


def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return (out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"])
    return (out_dict["boxes"], out_dict["scores"], out_dict["labels"])


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])


def save_jit_model(script=False):
    model_funcs = [torchvision.models.detection.fasterrcnn_resnet50_fpn,
                   torchvision.models.detection.maskrcnn_resnet50_fpn]

    names = ["faster_rcnn", "mask_rcnn"]

    for name, model_func in zip(names, model_funcs):
        if script:
            model = model_func(num_classes=50, pretrained_backbone=False)
        else:
            model = TraceWrapper(model_func(num_classes=50, pretrained_backbone=False))

        model.eval()
        in_size = 100
        inp = torch.rand(1, 3, in_size, in_size)

        with torch.no_grad():
            out = model(inp)

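            if script:
                out = dict_to_tuple(out[0])
                script_module = do_script(model)
                # In scripting mode the detection models return (losses, detections),
                # so index 1 selects the list of per-image detection dicts.
                script_out = script_module([inp[0]])[1]
                script_out = dict_to_tuple(script_out[0])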
            else:
                script_module = do_trace(model)
                script_out = script_module(inp)

            assert len(out[0]) > 0 and len(script_out[0]) > 0

            # compare bbox coord
            print(np.max(np.abs(out[0].numpy() - script_out[0].numpy())))

            torch._C._jit_pass_inline(script_module.graph)
            torch.jit.save(script_module, name + ".pt")

And for Mask R-CNN, this is the list of unimplemented ops. It seems doable now, much better than when I last tried half a year ago.

@kevinthesun The script to repro: https://github.com/masahi/torchscript-to-tvm/blob/master/maskrcnn_test.py

To convert roi_align, I use the custom convert map mechanism we have in the Torch frontend. Since Relay now has a scatter op, converting aten::scatter should be straightforward. torchvision::nms seems to be the only scary op, but I hope we can convert it the same way as roi_align.

['aten::nonzero',
 'aten::_shape_as_tensor',
 'aten::index',
 'aten::scalar_tensor',
 'aten::unbind',
 'aten::__interpolate',
 'aten::__and__',
 'aten::scatter',
 'torchvision::nms']
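A minimal sketch of the custom convert map mechanism for roi_align. It assumes the converter signature `(inputs, input_types)` used by TVM's PyTorch frontend and the torchvision::roi_align argument order (input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, and in newer torchvision versions aligned). The helper name `convert_roi_align` is hypothetical, and the `aligned` flag is deliberately not handled here.

```python
def convert_roi_align():
    """Hypothetical helper: build a converter for torchvision::roi_align."""

    def _impl(inputs, input_types):
        # Imported lazily so this snippet can be loaded without TVM installed.
        from tvm import relay

        # Assumed schema: roi_align(input, rois, spatial_scale, pooled_height,
        #                           pooled_width, sampling_ratio[, aligned])
        data, rois = inputs[0], inputs[1]
        spatial_scale = inputs[2]
        pooled_size = (inputs[3], inputs[4])
        sampling_ratio = inputs[5]
        # Assumption: aligned=True (inputs[6], if present) is ignored in this sketch.
        return relay.op.vision.roi_align(
            data, rois, pooled_size, spatial_scale, sampling_ratio
        )

    return _impl


# Keys are TorchScript op names exactly as they appear in the graph.
custom_map = {"torchvision::roi_align": convert_roi_align()}
```

The map would then be passed as `relay.frontend.from_pytorch(script_module, shape_list, custom_convert_map=custom_map)`, so ops missing from the built-in table can be filled in without patching the frontend.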

Cool! I will try this. Thx!

This works! Only a few ops are missing now.


Are there any updates on supporting the other ops? Is the Torchvision Mask R-CNN model supported?

I am facing an attribute issue (with an infinite float value, if I’m not mistaken). Does this mean all the other ops are supported?

Please refer to: Compiling MaskRCNN from Torchvision model zoo issues

Mask R-CNN is supported; we run it in every CI job. There is even a tutorial for it: https://github.com/apache/tvm/blob/main/gallery/how_to/deploy_models/deploy_object_detection_pytorch.py
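A rough sketch of the end-to-end flow with the Relay VM, which is needed because detection outputs have data-dependent shapes. It assumes the traced `mask_rcnn.pt` written by the `save_jit_model` script earlier in this thread (input size 100), an input name of `input0`, and an LLVM CPU target; the function name `compile_and_run_with_vm` is hypothetical, and the random image stands in for real data.

```python
def compile_and_run_with_vm(model_path="mask_rcnn.pt", in_size=100, target="llvm"):
    """Hypothetical helper: convert a traced detection model and run it on the Relay VM."""
    # Imported lazily so this snippet can be loaded without torch/TVM installed.
    import numpy as np
    import torch
    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine

    script_module = torch.jit.load(model_path)
    input_name = "input0"  # assumed name; the frontend accepts any string
    shape_list = [(input_name, (1, 3, in_size, in_size))]
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)

    # The VM (not the graph executor) handles the dynamic output shapes.
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
        vm_exec = relay.vm.compile(mod, target=target, params=params)

    vm = VirtualMachine(vm_exec, tvm.cpu())
    img = np.random.rand(1, 3, in_size, in_size).astype("float32")
    vm.set_input("main", **{input_name: img})
    # Outputs follow the tuple order produced by the trace wrapper:
    # boxes, scores, labels (and masks for Mask R-CNN).
    return vm.run()
```

Disabling `FoldScaleAxis` mirrors the linked tutorial's setup; a plain `opt_level=3` context is the obvious first thing to try if that pass is not an issue for a given model.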


@masahi Thanks for that. Will this compilation (using the Relay VM) work for the Vitis AI workflow too?

I don’t know what vitis AI does.