[VTA] A Workaround for Deploying Faster R-CNN on Target ext_dev (VTA and ARM CPU)

Hi, recently I have been working on deploying Faster R-CNN on VTA. Thanks to the versatility of schedules (see [TOPI] Using x86 schedules for ARM conv2d), most LLVM op strategies can be used on VTA.

However, using the generic argsort strategy on VTA has a small problem. I have a preliminary workaround that fixes it at runtime; solving it at compile time is beyond my ability.

Argsort Strategy Problem

The problem is caused by the memory allocation strategy. VTA uses VTABufferAlloc to obtain a vta::DataBuffer pointer, but this pointer is then misused as the data's virtual address.

From the printed lowered schedule, I guess that the pass somehow cannot transform the arguments of the extern function. If tvm.contrib.sort.argsort_nms could take compute_ptr, valid_count_ptr, and argsort_nms_cpu_ptr as arguments, this problem would be solved at compile time.

PrimFunc([data, valid_count, hybrid_nms.v1]) attrs={"tir.noalias": (bool)1, "global_symbol": "test_nms"} {
  ...
  let valid_count_ptr = VTABufferCPUPtr(tvm_thread_context(VTATLSCommandHandle()), valid_count)
  let data_ptr = VTABufferCPUPtr(tvm_thread_context(VTATLSCommandHandle()), data)
  // attr [compute] storage_scope = "global"
  allocate compute[float32 * 300]
  let compute_ptr = VTABufferCPUPtr(tvm_thread_context(VTATLSCommandHandle()), compute)
  // attr [argsort_nms_cpu] storage_scope = "global"
  allocate argsort_nms_cpu[int32 * 50]
  let argsort_nms_cpu_ptr = VTABufferCPUPtr(tvm_thread_context(VTATLSCommandHandle()), argsort_nms_cpu)
  ...
  for (j, 0, 50) {
    compute_ptr[j] = data_ptr[((j*6) + 1)]
  }
  // attr [0] extern_scope = 0
  tvm_call_packed("tvm.contrib.sort.argsort_nms", tvm_stack_make_array(compute, tvm_stack_make_shape(1, 50), 0, 2, 0f, 0), 
    tvm_stack_make_array(valid_count, tvm_stack_make_shape(1), 0, 1, 0, 0), 
    tvm_stack_make_array(argsort_nms_cpu, tvm_stack_make_shape(1, 50), 0, 2, 0, 0), 1, (bool)0)
  ...

My Workaround

It seems that calling VTABufferCPUPtr to convert the DataBuffer pointer into the data's virtual address would solve this problem at runtime. I chose a simpler method: using reinterpret_cast to read the data's virtual address directly.

TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  DLTensor *input = args[0];
  DLTensor *sort_num = args[1];
  DLTensor *output = args[2];
  int32_t axis = args[3];
  bool is_ascend = args[4];
  
  auto dtype = input->dtype;
  auto data_ptr_tmp = static_cast<int32_t *>(input->data);
  auto data_ptr = reinterpret_cast<float *>(*data_ptr_tmp);
  auto sort_num_ptr_tmp = static_cast<int32_t *>(sort_num->data);
  auto sort_num_ptr = reinterpret_cast<int32_t *>(*sort_num_ptr_tmp);
  auto output_data_ptr_tmp = static_cast<int32_t *>(output->data);
  auto output_data_ptr = reinterpret_cast<int32_t *>(*output_data_ptr_tmp);
  ...
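To make the double indirection behind the reinterpret_cast trick concrete, here is a minimal Python/ctypes sketch. It assumes that the buffer handle points to a struct whose first member is the data's virtual address, similar to vta::DataBuffer; `FakeDataBuffer` and its `phy_addr` field are hypothetical stand-ins, not the VTA runtime's actual definition.

```python
import ctypes

# Hypothetical stand-in for vta::DataBuffer: the first member is ASSUMED
# to be the virtual address of the data, followed by a physical address.
class FakeDataBuffer(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),
                ("phy_addr", ctypes.c_uint64)]

# The real data the sort kernel should read.
values = (ctypes.c_float * 4)(3.0, 1.0, 2.0, 0.5)
buf = FakeDataBuffer(ctypes.addressof(values), 0)

# What DLTensor::data holds in the problematic lowering: the address of
# the buffer struct, not the address of the floats.
handle = ctypes.addressof(buf)

# Workaround: read the first pointer-sized word at `handle`, then treat
# that word as the float* the kernel actually needs.
first_word = ctypes.cast(handle, ctypes.POINTER(ctypes.c_void_p)).contents.value
data_ptr = ctypes.cast(first_word, ctypes.POINTER(ctypes.c_float))

print([data_ptr[i] for i in range(4)])  # [3.0, 1.0, 2.0, 0.5]
```

Reading `handle` directly as float* would instead interpret the address bits as floats, which is the garbage the unpatched kernel sees.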

Because the NMS strategy uses this extern function, I can now run NMS on VTA. Here is my code:

from __future__ import absolute_import, print_function

import os
import time
import numpy as np
import vta
import tvm
import topi
from tvm import te
from tvm import rpc, autotvm, relay
from vta.testing import simulator
assert tvm.runtime.enabled("rpc")

env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
if env.TARGET not in ["sim", "tsim"]:
    # Get remote from tracker node if environment variable is set.
    # To set up the tracker, you'll need to follow the "Auto-tuning
    # a convolutional network for VTA" tutorial.
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    # Otherwise if you have a device you want to program directly from
    # the host, make sure you've set the variables below to the IP of
    # your board.
    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
    device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
    if not tracker_host or not tracker_port:
        remote = rpc.connect(device_host, int(device_port))
    else:
        remote = autotvm.measure.request_remote(env.TARGET,
                                                tracker_host,
                                                int(tracker_port),
                                                timeout=10000)
    # Reconfigure the JIT runtime and FPGA.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    reconfig_start = time.time()
    vta.reconfig_runtime(remote)
    vta.program_fpga(remote, bitstream=None)
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
else:
    remote = rpc.LocalSession()

# Get execution context from remote
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

dshape = (1, 50, 6)
data = te.placeholder(dshape, name="data")
valid_count = te.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
top_k = -1
out = topi.vision.nms.non_max_suppression(data, valid_count, iou_threshold=iou_threshold,
                                          force_suppress=force_suppress, top_k=top_k)
np_data = np.random.random(dshape).astype(np.float32)
np_valid_count = np.array([8]).astype(np.int32)
s = topi.generic.schedule_nms(out)

with vta.build_config(disabled_pass={"AlterOpLayout"}):
    m = tvm.lower(s, [data, valid_count, out], name="test_nms")
    print(m)
    f = tvm.build(m, target=target, target_host=env.target_host)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape[:2], dtype=np.int32), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
print(tvm_out)

@thierry @FrozenGene


The pointer should be 64 bits wide on my virtual machine. After correcting this, I have deployed Faster R-CNN on target ext_dev (VTA and ARM CPU). This is only a beginning for me to try more interesting things on VTA. To complete the deployment, I still need to do a lot of work in graph pack to achieve hardware acceleration of most convolutional layers. I would appreciate your suggestions.

  auto data_ptr_tmp = static_cast<int64_t *>(input->data);
  auto data_ptr = reinterpret_cast<float *>(*data_ptr_tmp);
  auto sort_num_ptr_tmp = static_cast<int64_t *>(sort_num->data);
  auto sort_num_ptr = reinterpret_cast<int32_t *>(*sort_num_ptr_tmp);
  auto output_data_ptr_tmp = static_cast<int64_t *>(output->data);
  auto output_data_ptr = reinterpret_cast<int32_t *>(*output_data_ptr_tmp);
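Why the int32_t version failed on a 64-bit machine can be seen with a small ctypes sketch: reading only four bytes of an eight-byte address keeps just half of it. The address value below is made up for illustration, and the printed int32 result assumes a little-endian host. A pointer-sized integer (intptr_t in C++) would be the portable choice instead of a fixed width.

```python
import ctypes

# Pretend this is the first word of the buffer struct: a 64-bit address
# whose high half is non-zero (hypothetical value, typical of 64-bit heaps).
addr = 0x7F123456789A
word = ctypes.c_uint64(addr)

# Original workaround: read the word through an int32_t*, which sees only 4 bytes.
as_int32 = ctypes.cast(ctypes.pointer(word), ctypes.POINTER(ctypes.c_int32)).contents.value
# Corrected workaround: read it through an int64_t*, which sees all 8 bytes.
as_int64 = ctypes.cast(ctypes.pointer(word), ctypes.POINTER(ctypes.c_int64)).contents.value

print(hex(as_int32 & 0xFFFFFFFF))  # 0x3456789a on little-endian: truncated address
print(hex(as_int64))               # 0x7f123456789a: full address
print(ctypes.sizeof(ctypes.c_void_p))  # pointer width in bytes on this host
```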

At present, there are few problems with quantization. The next step is to modify the graph pack function to transform most convolutions into the NCHW1n16c layout so they can be accelerated. I need to add some op names to complete the AST traversal in the graph pack function. If I have made a mistake, please correct me.
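For readers unfamiliar with the packed layout, converting NCHW to NCHW1n16c amounts to a reshape followed by a transpose. The NumPy sketch below illustrates that idea (my own sketch, not code taken from graph pack); bfactor=1 and cfactor=16 correspond to NCHW1n16c, and N and C are assumed to divide evenly by those factors.

```python
import numpy as np

def pack_nchw(x, bfactor=1, cfactor=16):
    """Convert an NCHW array to NCHW{bfactor}n{cfactor}c.

    Output axes: (N//bfactor, C//cfactor, H, W, bfactor, cfactor).
    """
    n, c, h, w = x.shape
    assert n % bfactor == 0 and c % cfactor == 0, "dims must divide evenly"
    # Split N into (N//bfactor, bfactor) and C into (C//cfactor, cfactor).
    x = x.reshape(n // bfactor, bfactor, c // cfactor, cfactor, h, w)
    # Move the two inner block dimensions to the end.
    return x.transpose(0, 2, 4, 5, 1, 3)

x = np.arange(1 * 32 * 2 * 2).reshape(1, 32, 2, 2)
packed = pack_nchw(x)
print(packed.shape)  # (1, 2, 2, 2, 1, 16)
# Element (n, c, h, w) lands at (n // 1, c // 16, h, w, n % 1, c % 16),
# e.g. x[0, 21, 1, 0] is packed[0, 1, 1, 0, 0, 5].
```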

How is the “lowered schedule” printed out?

Use the tvm.lower Python API; you need to pass the schedule and the input/output tensors:

 print(tvm.lower(s, [data, valid_count, out], name="test_nms"))

@thierry I have accelerated 42 convolution layers on VTA. I chose the faster_rcnn_resnet50_v1b_voc MXNet model, which has 56 convolution layers. I am going to work with my partner @ffffc on further optimization.


@hht, would you mind sharing your implementation of R-CNN on VTA?

https://drive.google.com/open?id=1io_uQjG9am5mYbFQ-c7h9nH07fYLmVnq

Here is my project, including the .so files. You can unzip it and run fasterRCNN_vta.py directly with fsim for VTA. I didn’t make a git commit because there are some compatibility problems in my code; you can use git status to review my changes. I am going to refactor the code later.

@jinchenglee @acapone13 @Augusto
I am sorry, the shared link no longer works. I have accelerated 52 convolution layers on VTA in my GitHub dev branch: https://github.com/i24361/incubator-tvm

The consistency problem of VTA on the ZCU104 platform has been shown to be an internal logic bug in VTA, according to [RFC][VTA]A HLS C VTA bug.

Due to the characteristics of BRAM, the fallback schedule for VTA conv2d produces faulty results on a real FPGA. There are two ways to solve this problem: one is auto-tuning, the other is constructing a bypass for VTA.


Hi, I was wondering how to run your code on a PYNQ board.

Since I am new to VTA, I do not know much about this and would like to run your code as a primer.

Hi MengboZ, I am sorry, I don’t currently work on this, and I don’t have a PYNQ board on hand. I recommend reading the official documentation.

@hht Hi, I was wondering whether you implemented nn.upsample in the graph_pack process. I am currently trying to run U-Net on VTA, but I have run into problems in the graph_pack process, and I suspect the cause is nn.upsample or torch.cat. A full description is in another post I wrote (Can Upsample be implemented on VTA in graph_pack?). Sorry to bother you. I am new to TVM and VTA and have not been able to solve this problem after a long time of trying. Have you encountered this problem, or could you offer some suggestions? I look forward to hearing from you. Thank you!

@MengboZ Hi, I didn’t implement nn.upsample in the graph_pack process. Perhaps you can implement a VTA compute function and schedule function for nn.upsample in the Relay op strategy. If you don’t want to implement nn.upsample for VTA, you can convert NCHWnc to NCHW before nn.upsample and convert NCHW back to NCHWnc after it. That way, the ARM CPU implementation of nn.upsample is used instead of a VTA one.
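A NumPy sketch of the suggested bypass, assuming an NCHW1n16c packed layout: unpack to NCHW, upsample the plain NCHW tensor on the CPU side, then repack. Nearest-neighbour upsampling by an integer factor via np.repeat stands in for nn.upsample here, and the helper names `pack`, `unpack`, and `upsample_nearest` are hypothetical; in Relay the same idea would be expressed with reshape/transpose ops around the upsample op.

```python
import numpy as np

def pack(x, bn=1, bc=16):
    """NCHW -> (N//bn, C//bc, H, W, bn, bc), i.e. NCHW{bn}n{bc}c."""
    n, c, h, w = x.shape
    return x.reshape(n // bn, bn, c // bc, bc, h, w).transpose(0, 2, 4, 5, 1, 3)

def unpack(x):
    """(N//bn, C//bc, H, W, bn, bc) -> NCHW; the inverse of pack."""
    no, co, h, w, bn, bc = x.shape
    return x.transpose(0, 4, 1, 5, 2, 3).reshape(no * bn, co * bc, h, w)

def upsample_nearest(x, scale=2):
    """Nearest-neighbour upsampling of NCHW data by an integer factor."""
    return x.repeat(scale, axis=2).repeat(scale, axis=3)

x = np.random.rand(1, 32, 4, 4).astype("float32")
packed = pack(x)                               # layout used on the VTA side
out = pack(upsample_nearest(unpack(packed)))   # upsample runs on plain NCHW
print(out.shape)  # (1, 2, 8, 8, 1, 16)
```

The cost of this design is two extra layout transforms around every upsample, which is the trade-off against writing a dedicated VTA compute and schedule.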