RPC failure while autotuning on Jetson TX2

I registered a Jetson TX2 with an RPC tracker running on my desktop PC (the tracker/server commands are listed after the traceback below). When I start tuning, I get the following error:

Extract tasks...
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
Start tuning...
[Task  1/24]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/2000) | 0.00 s
Traceback (most recent call last):

  File "./run.py", line 75, in <module>
    tuner.autotune()

  File "/home/can/Dev/tvm_wd/tuner.py", line 177, in autotune
    autotvm.callback.log_to_file(tmp_log_file)])

  File "/home/can/Dev/tvm/python/tvm/autotvm/tuner/xgboost_tuner.py", line 90, in tune
    super(XGBTuner, self).tune(*args, **kwargs)

  File "/home/can/Dev/tvm/python/tvm/autotvm/tuner/tuner.py", line 111, in tune
    measure_batch = create_measure_batch(self.task, measure_option)

  File "/home/can/Dev/tvm/python/tvm/autotvm/measure/measure.py", line 253, in create_measure_batch
    attach_objects = runner.set_task(task)

  File "/home/can/Dev/tvm/python/tvm/autotvm/measure/measure_methods.py", line 214, in set_task
    raise RuntimeError("Cannot get remote devices from the tracker. "

RuntimeError: Cannot get remote devices from the tracker. Please check the status of tracker by 'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' and make sure you have free devices on the queue status.

Exception in thread Thread-11:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/can/Dev/tvm/python/tvm/autotvm/measure/measure_methods.py", line 579, in _check
    while not ctx.exist:  # wait until we get an available device
  File "/home/can/Dev/tvm/python/tvm/_ffi/runtime_ctypes.py", line 189, in exist
    self.device_type, self.device_id, 0) != 0
  File "/home/can/Dev/tvm/python/tvm/_ffi/runtime_ctypes.py", line 183, in _GetDeviceAttr
    device_type, device_id, attr_id)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 315, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 250, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 239, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (4) /home/can/Dev/tvm/build/libtvm.so(TVMFuncCall+0x65) [0x7fe1f4db43b5]
  [bt] (3) /home/can/Dev/tvm/build/libtvm.so(+0xd2a831) [0x7fe1f4db2831]
  [bt] (2) /home/can/Dev/tvm/build/libtvm.so(tvm::runtime::RPCDeviceAPI::GetAttr(DLContext, tvm::runtime::DeviceAttrKind, tvm::runtime::TVMRetValue*)+0x149) [0x7fe1f4e12ef9]
  [bt] (1) /home/can/Dev/tvm/build/libtvm.so(+0xd9decc) [0x7fe1f4e25ecc]
  [bt] (0) /home/can/Dev/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x7c) [0x7fe1f444d31c]
  File "../src/runtime/rpc/rpc_session.cc", line 903
TVMError: Check failed: code == RPCCode::kReturn: code=4
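
For reference, the tracker and the device were registered with the standard TVM commands, roughly as below; the tracker port (9090) and the device key ('tx2') match what my script passes to RPCRunner, and the desktop IP on the TX2 side is a placeholder.

# on the desktop PC
python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9090

# on the Jetson TX2
python -m tvm.exec.rpc_server --tracker=<desktop-ip>:9090 --key=tx2

# queue status, queried from the desktop
python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9090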

My main script (run.py) is here:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: can
"""

# import cv2
# numpy and matplotlib
import tvm.relay.testing.yolo_detection
import tvm.relay.testing.darknet
import tvm
from tvm import relay
import os
from time import time

from tvm import autotvm
from tuner import TUNER


######################################################################
# Initials
# -----------------------
CALIBRATION_SAMPLES = 16
num_threads = 4
os.environ["TVM_NUM_THREADS"] = str(num_threads)

TARGET_HOST_PC = 'llvm'
TARGET_HOST_JETSON = 'llvm -target=aarch64-linux-gnu'

LOCAL_RUNNER = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150))
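# NOTE: the device key ('tx2') and the tracker address/port below must match what the TX2 registered with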
RPC_RUNNER = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.RPCRunner('tx2','0.0.0.0', 9090,
                                 number=20, repeat=3, timeout=4, min_repeat_ms=150))

args = {
    'model_name': 'plate_detector-tiny',
    'framework': 'darknet',
    'model_dir': 'models',
    'batch_size': 1,
    'width': 416,
    'height': 416,
    'channel': 3,
    'dtype': 'float32',
    'target_host':TARGET_HOST_JETSON,
    'tuner': 'xgb',
    'n_trial': 2000,
    'early_stopping': 600,
    'measure_option': RPC_RUNNER,
    'transfer_learning': True,
    'quantize': False,
    'try_winograd': True}

######################################################################
# Prepare the environment
# -----------------------
target = tvm.target.create('cuda')
tuner = TUNER(**args)
args.update({'target': target})
tuner.update(**args)
# mod, params = [0,0]
net, mod, params = tuner.get_network.from_darknet()
# mod, params = tuner.get_network.from_mxnet()
# mod, params = tuner.get_network.from_torch()
args.update({'mod':mod, 'params':params})
tuner.update(**args)



if tuner.quantize:
    tuner.quantize_model()

tuner.autotune()

print("Export library...")
tuner.export_library(mod, params, target)
# tuned_graph, tuned_lib, tuned_params = tuner.import_library()

# tuner.evaluate()

# img_path = "./66429ef6e83191b4.jpg"
# print("test the model on an image...")
# tuner.test_darknet_yolo(img_path, net, target)

And this is the tuner class (tuner.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: can
"""


import numpy as np
import matplotlib.pyplot as plt
import tvm
import sys
from time import time
import os
import multiprocessing as mp

# from ctypes import *
from tvm import relay
from tvm.relay.testing.darknet import __darknetffi__
import tvm.relay.testing.yolo_detection
import tvm.relay.testing.darknet
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
# from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import torch
import torchvision
import mxnet as mx


class TUNER(object):
    
    def __init__(self, **kwargs):
        self.batch_size = 1
        self.dtype = 'float32'
        self.model_dir = 'models'
        self.transfer_learning = True
        self.quantize = False
        self.__dict__.update(kwargs)
        self.shape = (self.batch_size, self.channel, self.height, self.width)
        
        self.backup_dir = os.path.join(self.model_dir,self.framework,self.model_name)
        
        self.log_file = "./logs/"+self.framework+'/' + self.model_name + '/' + self.model_name + ".tune.log"
        if not os.path.exists("./logs/"+self.framework):
            os.mkdir("./logs/"+self.framework)
        if not os.path.exists("./logs/"+self.framework+'/' + self.model_name):
            os.mkdir("./logs/"+self.framework+'/' + self.model_name)
        if not os.path.exists(self.log_file):
            os.mknod(self.log_file)
        
        self.out_dir = "./output/" + self.framework+'/' + self.model_name + '/'
        if not os.path.exists("./output/"+self.framework):
            os.mkdir("./output/"+self.framework)
        if not os.path.exists("./output/"+self.framework+'/' + self.model_name):
            os.mkdir("./output/"+self.framework+'/' + self.model_name)
        
        mydict = self.__dict__
        self.get_network = self.get_network(**mydict)
    
    def update(self, **kwargs):
        self.__dict__.update(kwargs)
    
    class get_network():
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)
            
        def from_darknet(self):
            ''' Generate TVM Module and Parameters for Darknet models '''
            cfg_path = os.path.join(self.backup_dir, self.model_name + '.cfg')
            weights_path = os.path.join(self.backup_dir, self.model_name + '.weights')
            DARKNET_LIB = __darknetffi__.dlopen(os.path.join(self.model_dir, self.framework,'libdarknet2.0.so'))
            
            net = DARKNET_LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
            
            # input_shape = [self.batch_size, net.c, net.h, net.w]
            data = np.empty(self.shape, self.dtype)
            print('\nData Shape: ', data.shape, '\n')
            
            print("Converting darknet to relay functions...")
            mod, params = relay.frontend.from_darknet(net, dtype=self.dtype, shape=data.shape)
            
            return net, mod, params
            
    
    
        def from_torch(self):
            weights_path = os.path.join(self.backup_dir, self.model_name + '.pth')
            weights = torch.load(weights_path)
            model = getattr(torchvision.models, self.model_name)(pretrained=False)
            model.load_state_dict(weights)
            model = model.eval()
            
            # We grab the TorchScripted model via tracing
            input_data = torch.randn(self.shape, dtype=torch.float32)
            scripted_model = torch.jit.trace(model, input_data).eval()
            
            shape_list = [('data', self.shape)]
            mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
            
            return mod, params
            
        
        def from_mxnet(self):
            weights_path = os.path.join(self.backup_dir, self.model_name)
            data = np.random.uniform(size=self.shape).astype(self.dtype)
            shape_dict = {'data': data.shape}
            sym, args, auxs = mx.model.load_checkpoint(weights_path, 0)
            mod, params = relay.frontend.from_mxnet(sym, shape_dict, arg_params=args, aux_params=auxs)
            
            return mod, params
            
    
    
    def quantize_model(self):
        with relay.quantize.qconfig(calibrate_mode='global_scale', 
                                    global_scale=8.0,
                                    store_lowbit_output=False):
            self.mod = relay.quantize.quantize(self.mod, params=self.params)
        # return mod


    def autotune(self):
        print("Extract tasks...")
        tasks = autotvm.task.extract_from_program(self.mod, target=self.target,
                                                      target_host=self.target_host,
                                                      params=self.params)
        if self.try_winograd:
            for i in range(len(tasks)):
                try:  # try winograd template
                    tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                              tasks[i].target, tasks[i].target_host)
                    input_channel = tsk.workload[1][1][1]
                    if input_channel >= 64:
                        tasks[i] = tsk
                except Exception:
                    pass
        
        print("Start tuning...")
        log_filename = self.log_file
        tuner = self.tuner
        
        # create tmp log file
        tmp_log_file = log_filename + ".tmp"
        if not self.transfer_learning:
            if os.path.exists(tmp_log_file):
                os.remove(tmp_log_file)
        else:
            # select actual best logs
            if not os.path.exists(tmp_log_file):
                os.mknod(tmp_log_file)
    
        for i, tsk in enumerate(reversed(tasks)):
            prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
    
            # create tuner
            if tuner == 'xgb' or tuner == 'xgb-rank':
                tuner_obj = XGBTuner(tsk, loss_type='rank')
            elif tuner == 'ga':
                tuner_obj = GATuner(tsk, pop_size=100)
            elif tuner == 'random':
                tuner_obj = RandomTuner(tsk)
            elif tuner == 'gridsearch':
                tuner_obj = GridSearchTuner(tsk)
            else:
                raise ValueError("Invalid tuner: " + tuner)
    
            if self.transfer_learning:
                if os.path.isfile(tmp_log_file):
                    tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
    
            # do tuning
            tuner_obj.tune(n_trial=min(self.n_trial, len(tsk.config_space)),
                           early_stopping=self.early_stopping,
                           measure_option=self.measure_option,
                           callbacks=[
                               autotvm.callback.progress_bar(self.n_trial, prefix=prefix),
                               autotvm.callback.log_to_file(tmp_log_file)])
    
        # pick best records to a cache file
        autotvm.record.pick_best(tmp_log_file, log_filename)
        os.remove(tmp_log_file)
    
    
    def build_relay(self, mod, params, target):
        with autotvm.apply_history_best(self.log_file):
            print("Compiling with the best configuration logged...")
            with relay.build_config(opt_level=3):
                tuned_graph, tuned_lib, tuned_params = relay.build_module.build(
                    mod, target=target, params=params)
        return tuned_graph, tuned_lib, tuned_params
    
    
    def export_library(self, mod, params, target):
        tuned_graph, tuned_lib, tuned_params = self.build_relay(mod, params, target)
        print("exporting tuned libraries...")
        tuned_lib.export_library(self.out_dir+self.model_name+'.so')
        with open(self.out_dir+self.model_name+'.json', 'w') as f:
            f.write(tuned_graph)
        with open(self.out_dir+self.model_name+'.params', 'wb') as f:
            f.write(relay.save_param_dict(tuned_params))
    
    
    def import_library(self):
        tuned_lib = tvm.runtime.load_module(self.out_dir+self.model_name+'.so')
        with open(self.out_dir+self.model_name+'.json', 'r') as f:
            tuned_graph = f.read()
        with open(self.out_dir+self.model_name+'.params', 'rb') as f:
            tuned_params = bytearray(f.read())
        return tuned_graph, tuned_lib, tuned_params


    def evaluate(self):
        ctx = tvm.context(str(self.target), 0)
        try:
            tuned_graph, tuned_lib, tuned_params = self.import_library()
            module = runtime.create(tuned_graph, tuned_lib, ctx)
            module.load_params(tuned_params)
        except Exception:
            tuned_graph, tuned_lib, tuned_params = self.build_relay(self.mod, self.params, self.target)
            module = runtime.create(tuned_graph, tuned_lib, ctx)
            module.set_input(**tuned_params)
        
        
        data_tvm = tvm.nd.array((np.random.uniform(size=self.shape)).astype(self.dtype))
        module.set_input('data', data_tvm)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=20, repeat=500)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))
    
    
    def test_darknet_yolo(self, img_path, net, target):
        tuned_graph, tuned_lib, tuned_params = self.import_library()
        # load parameters
        ctx = tvm.context(str(target), 0)
        [neth, netw] = self.shape[2:]
        data = tvm.relay.testing.darknet.load_image(img_path, netw, neth)
        font_path = "arial.ttf"
        
        # load original image
        img = tvm.relay.testing.darknet.load_image_color(img_path)
        _, im_h, im_w = img.shape
        
        # create runtime module
        module = runtime.create(tuned_graph, tuned_lib, ctx)
        module.load_params(tuned_params)
        # set inputs
        tvm_input = tvm.nd.array(data.astype(self.dtype), ctx)
        module.set_input('data', tvm_input)
        # module.set_input(**params)
        
        thresh = 0.6
        nms_thresh = 0.45
        
        # execute the module
        module.run()
        # get outputs
        tvm_out = []
        for i in range(2):
            layer_out = {}
            layer_out['type'] = 'Yolo'
            # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
            layer_attr = module.get_output(i*4+3).asnumpy()
            layer_out['biases'] = module.get_output(i*4+2).asnumpy()
            layer_out['mask'] = module.get_output(i*4+1).asnumpy()
            out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
                          layer_attr[2], layer_attr[3])
            layer_out['output'] = module.get_output(i*4).asnumpy().reshape(out_shape)
            layer_out['classes'] = layer_attr[4]
            tvm_out.append(layer_out)
        
        dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out)
        last_layer = net.layers[net.n - 1]
        tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
        
        
        label_path = os.path.join(self.backup_dir, self.model_name + ".names")
        with open(label_path) as f:
            content = f.readlines()
        names = [x.strip() for x in content]
            
        tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
        plt.imshow(img.transpose(1, 2, 0))
        plt.show()
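
To narrow this down, a direct request against the tracker should show whether the 'tx2' key is visible and free at all, independent of autotvm. A minimal check, assuming the tracker address (0.0.0.0:9090) and device key ('tx2') from the script above and the same TVM version:

from tvm import rpc

# connect to the tracker that RPCRunner points at
tracker = rpc.connect_tracker('0.0.0.0', 9090)
print(tracker.text_summary())   # the 'tx2' key should show up with a free queue slot

# request a session on the TX2 and touch its GPU context;
# this is essentially what set_task() fails to do in the traceback above
remote = tracker.request('tx2')
ctx = remote.context('cuda', 0)
print('device exists:', ctx.exist)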