Relay OpenCL inference result error

I tried testing a TVM SSD model with Relay via the RPC app. The remote.cpu device gives correct detection results, but remote.cl gives wrong ones: the output is a tensor filled with -1, like this:

[[[-1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1.]
  …
  [-1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1.]
  [-1. -1. -1. -1. -1. -1.]]]

Is there a bug in the OpenCL kernel source code? The test code below is modified from https://github.com/apivovarov/mxnet-ssd-tvm, where the model can also be downloaded.

import os
import time

import mxnet as mx
import numpy as np
import tvm
from mxnet.model import load_checkpoint
from PIL import Image
from tvm.contrib import ndk, util
from tvm.contrib.download import download
# PIL preprocessing; this image is only used to derive the input shape
# (dshape) -- the actual network input comes from the cv2 block below.
org_img = Image.open('./dog.jpg')
org_img = org_img.resize((512, 512))
img = np.asarray(org_img).astype(np.float32).copy()
img = img.transpose(2, 0, 1)
img /= 255.0
img = img[np.newaxis, :]
dshape = img.shape
shape_dict = {'data': img.shape}
dtype = "float32"

import cv2

# cv2 preprocessing: BGR -> RGB, mean subtraction, HWC -> NCHW, add batch dim.
test_image_path = "./mxnet-ssd-tvm/dog.jpg"
image = cv2.imread(test_image_path)
img_data = cv2.resize(image, (dshape[2], dshape[3]))
img_data = img_data[:, :, (2, 1, 0)].astype(np.float32)
img_data -= np.array([123, 117, 104])
img_data = np.transpose(np.array(img_data), (2, 0, 1))
img_data = np.expand_dims(img_data, axis=0)


######################################################################
# Convert and compile the model with NNVM or Relay.

from tvm import rpc
tracker_host = "0.0.0.0"
tracker_port = 6007
key = "android"

tracker = rpc.connect_tracker(tracker_host, tracker_port)
remote = tracker.request(key, priority=0, session_timeout=60)
arch = "arm64-v8a"

## opencl
# ctx = remote.cl(0)
# target = "opencl"
# target_host = 'llvm -target=%s-linux-android' % arch

## arm_cpu
ctx = remote.cpu(0)
target = 'llvm -target=%s-linux-android' % arch  # -device=arm_cpu
target_host = None

temp = util.tempdir()

inf_json = "deploy_ssd_mobilenet_512/deploy_ssd_mobilenet_512-symbol.json"

print("mx.sym.load: " + inf_json)
sym = mx.sym.load(inf_json)
checkp = "deploy_ssd_mobilenet_512/deploy_ssd_mobilenet_512"

print("load_checkpoint: " + checkp)
_, arg_params, aux_params = load_checkpoint(checkp, 0)

import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
    "-f", "--frontend",
    help="Frontend for compilation, nnvm or relay",
    type=str,
    default="relay")
args = parser.parse_args()

# The parsed value is overridden by the TVM version check below.
high_version = False
if tvm.__version__ == "0.6.0":
    args.frontend = "nnvm"
else:
    args.frontend = "relay"
    high_version = True


if high_version:
    from tvm import relay
    net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params,
                                            aux_params=aux_params)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build_module.build(
            net, target, params=params, target_host=target_host)
else:
    import nnvm
    from nnvm import compiler
    from nnvm.frontend import from_mxnet
    net, params = from_mxnet(sym, arg_params, aux_params)
    with compiler.build_config(opt_level=3):
        graph, lib, params = compiler.build(
            net, target, {"data": dshape}, params=params, target_host=target_host)

lib_file = temp.relpath("model.so")
graph_file = temp.relpath('model.json')
params_file = temp.relpath('model.params')

lib.export_library(lib_file, ndk.create_shared)
with open(graph_file, "w") as fo:
    if high_version:
        fo.write(graph)
    else:
        fo.write(graph.json())
with open(params_file, "wb") as fo:
    if high_version:
        fo.write(relay.save_param_dict(params))
    else:
        fo.write(nnvm.compiler.save_param_dict(params))

#upload model
print('Run %s test ...' % ctx)

remote.upload(lib_file)
remote.upload(graph_file)
remote.upload(params_file)


class_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair",
               "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor"]

###
lib = remote.load_module("model.so")
graph = open(graph_file).read()
params = bytearray(open(params_file, "rb").read())

# load parameters
from tvm.contrib import graph_runtime
mod = graph_runtime.create(graph, lib, ctx)
mod.load_params(params)
###
input_data = tvm.nd.array(img_data.astype(dtype))

# data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype('float32'))
mod.set_input('data', input_data)
# execute
mod.run()
# get outputs
tvm_output = mod.get_output(0)

print(tvm_output)

print("Evaluate inference time cost...")
ftimer = mod.module.time_evaluator("run", ctx, number=1, repeat=1)
prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
print("time %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
# logging.info("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))

out = tvm_output.asnumpy()[0]

i = 0
for det in out:
    cid = int(det[0])
    if cid < 0:
        continue
    score = det[1]
    if score < 0.5:
        continue
    i += 1
    print(i, class_names[cid], det)

######################################################################
# Display result

def display(img, out, thresh=0.5):
    import random
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    mpl.rcParams['figure.figsize'] = (10, 10)
    pens = dict()
    plt.clf()
    plt.imshow(img)
    for det in out:
        cid = int(det[0])
        if cid < 0:
            continue
        score = det[1]
        if score < thresh:
            continue
        if cid not in pens:
            pens[cid] = (random.random(), random.random(), random.random())
        scales = [img.shape[1], img.shape[0]] * 2
        xmin, ymin, xmax, ymax = [int(p * s) for p, s in zip(det[2:6].tolist(), scales)]
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False,
                             edgecolor=pens[cid], linewidth=3)
        plt.gca().add_patch(rect)
        text = class_names[cid]
        plt.gca().text(xmin, ymin-2, '{:s} {:.3f}'.format(text, score),
                       bbox=dict(facecolor=pens[cid], alpha=0.5),
                       fontsize=12, color='white')
    plt.savefig('myres.jpg')
    # plt.show()

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
display(image, tvm_output.asnumpy()[0], thresh=0.45)

The TVM version is 0.7.dev0.

I have also tried the NNVM API with remote.cl; the detection result is wrong in the same way as with Relay.
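
In case it is relevant, here is how I would dump the generated OpenCL kernels for inspection after relay.build (a sketch; that the OpenCL code lives in the first entry of lib.imported_modules is an assumption about the module layout):

# Dump the generated device code when building with target="opencl".
# Assumption: the OpenCL module is the first imported module.
dev_module = lib.imported_modules[0]
with open("kernels.cl", "w") as f:
    f.write(dev_module.get_source())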

Thanks for your help!

I’ve tried your script with NVIDIA OpenCL, and it seems to work correctly.

Run remote[0]:opencl(0) test ...
[[[ 6.          0.9542923   0.60183966  0.13442558  0.8955139
    0.29527432]
  [-1.         -1.          0.60981846  0.13781077  0.9041296
    0.29533324]
  [11.          0.8844297   0.14165536  0.38350025  0.4080676
    0.9418137 ]
  ...
  [-1.         -1.         -1.         -1.         -1.
   -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.        ]
  [-1.         -1.         -1.         -1.         -1.
   -1.        ]]]
Evaluate inference time cost...
time 9.89 ms (0.00 ms)
1 car [6.         0.9542923  0.60183966 0.13442558 0.8955139  0.29527432]
2 dog [11.          0.8844297   0.14165536  0.38350025  0.4080676   0.9418137 ]

Perhaps it’s a problem specific to your OpenCL platform environment?
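
If it helps to narrow things down, one option is to run the same graph on both contexts and compare the raw outputs (a sketch; mod_cpu and mod_cl are hypothetical modules built with the llvm and opencl targets respectively, as in your script):

import numpy as np

def run_once(mod, data):
    # Set the input, execute the graph, and fetch the first output.
    mod.set_input('data', data)
    mod.run()
    return mod.get_output(0).asnumpy()

out_cpu = run_once(mod_cpu, input_data)  # built with the llvm target
out_cl = run_once(mod_cl, input_data)    # built with the opencl target
np.testing.assert_allclose(out_cl, out_cpu, rtol=1e-3, atol=1e-3)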

Hi @kazum, thanks for your reply. I have retested the model; perhaps it was a model bug. I have tested a new model from model_zoo, which was exported to *.json and *.params with:

# model_zoo here is assumed to be GluonCV's (from gluoncv import model_zoo).
net = model_zoo.get_model(model_name, pretrained=True)
net.hybridize()
x = mx.nd.random.normal(shape=(1, 3, 512, 512))
out1 = net(x)
net.export('net1', epoch=1)

Then I use:

sym, arg_params, aux_params = load_checkpoint(checkp, 0)
mod, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params)

How can I load the exported Gluon *.params directly into a Relay model, instead of converting to *.json and *.params first? Thanks a lot!

@zchuang11 I’m not sure if I understood you correctly, but you can directly pass a Gluon HybridBlock to from_mxnet without exporting json and params. There are some examples in the TVM tutorial directory.

For example (a sketch; the model name and input shape below are placeholders):
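
from gluoncv import model_zoo
from tvm import relay

block = model_zoo.get_model("ssd_512_mobilenet1.0_voc", pretrained=True)
block.hybridize()
# from_mxnet accepts a Gluon HybridBlock directly, so no json/params export is needed.
mod, params = relay.frontend.from_mxnet(block, shape={"data": (1, 3, 512, 512)})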

@kazum OK, thank you for your patient reply!

@kazum I have run into a new problem, can you help me fix it? https://discuss.tvm.ai/t/op-equal-is-not-supported/6303