Error "direct host side access to device memory is detected in ... , did you forget to bind?" when compiling an ONNX model with target cuda

Hi, when I load a model in ONNX format and compile it with target cuda, the error ‘Direct host side access to device memory is detected in fuse_reshape_broadcast_mul_conv2d_broadcast_mul_broadcast_add_elemwise_add. Did you forget to bind?’ is raised.

Traceback (most recent call last):
  File "from_onnx.py", line 84, in <module>
    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)
  File "/home/wyq/tvm/nnvm/python/nnvm/compiler/build_module.py", line 307, in build
    graph = graph.apply("GraphCompile")
  File "/home/wyq/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/wyq/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "/home/wyq/tvm/python/tvm/_ffi/_ctypes/function.py", line 54, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/wyq/tvm/nnvm/python/nnvm/compiler/build_module.py", line 124, in _build
    return tvm.build(funcs, target=target, target_host=target_host)
  File "/home/wyq/tvm/python/tvm/build_module.py", line 462, in build
    "Did you forget to bind?" % func.name)
ValueError: Direct host side access to device memory is detected in fuse_reshape_broadcast_mul_conv2d_broadcast_mul_broadcast_add_elemwise_add.
 Did you forget to bind?

So how can I fix it?

Can you share your model?

Also, please tell me your TVM commit hash (the output of git log). It might be due to the recent change I made to NNVM operator fusion (fusing reshape with conv2d seems fishy).

tag: 0.4
commit hash: 60769b77f9abe29aafabda4d5d1cd625e7c61f9f
net: resnet50 with a final 512-dim FC layer
onnx version: 1.2.1

onnx model: It’s not convenient to upload because of its size (230+ MB).

If you change this line to if(0), does it work?

I changed the code and rebuilt TVM as you suggested, but the same error “ValueError: Direct host side access to device memory is detected in fuse_reshape_broadcast_mul_conv2d_broadcast_mul_broadcast_add_elemwise_add. Did you forget to bind?” still happens.

ok, then my change is not related.

I can compile onnx resnet 50 model hosted at https://github.com/onnx/models/tree/master/resnet50 without problem. How did you get your onnx model?

I think if you change the target from “cuda” to “llvm”, there should be no error. Then you can dump your resnet 50 graph to a JSON file. Can you try this snippet and post the output somewhere?

graph, lib, params = nnvm.compiler.build(sym, "llvm", shape_dict, params=params)
print(graph.json())
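To see which fused kernel trips the check, the `func_name` attributes can be pulled out of that JSON. A minimal sketch, assuming the compiled graph JSON stores each kernel as a node with `"op": "tvm_op"` and an `attrs.func_name` entry, while inputs and parameters are `"op": "null"` nodes (the layout NNVM's graph runtime used around v0.4). The helper names `fused_kernels` and `kernels_containing` are hypothetical.

```python
import json


def fused_kernels(graph_json):
    """Return the func_name of every compiled kernel node in an NNVM graph JSON string."""
    graph = json.loads(graph_json)
    # Input/parameter nodes have op "null"; compiled (possibly fused) kernels have op "tvm_op".
    return [node["attrs"]["func_name"]
            for node in graph["nodes"]
            if node["op"] == "tvm_op"]


def kernels_containing(graph_json, *op_names):
    """Return the fused kernels whose name mentions every one of the given op names."""
    return [name for name in fused_kernels(graph_json)
            if all(op in name for op in op_names)]
```

For example, `kernels_containing(graph.json(), "reshape", "conv2d")` would list the kernels in which a reshape was fused together with a conv2d.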

Yes, there is no error when I change the target from “cuda” to “llvm”. The output of the snippet is here: https://github.com/yqwang/tvm_llvm_json/blob/master/resnet50_llvm_graph.json

I trained a resnet50 model in PyTorch and exported the checkpoint to ONNX format.

ok thanks. I’ll take a look.

Can you share pytorch code for resnet50, or did you use one from torchvision?

Sorry for not mentioning that my resnet50 contains SE blocks. Here is the PyTorch code for it: https://github.com/yqwang/tvm_llvm_json/blob/master/resnet.py

hmm, never heard of SELayer. I guess this is what is causing the error, because this is not a standard layer.
TVM and NNVM are tested mostly on standard imagenet models. If you try something new, weird errors might arise.

Is SELayer the same as this one?

yes,

import torch
import torch.nn as nn
import torch.nn.functional as F


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

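    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one value per channel, shape (b, c)
        y = F.avg_pool2d(x, kernel_size=x.size()[2:]).view(b, c)
        # Excitation: FC -> ReLU -> FC -> sigmoid, reshaped to (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # Scale: broadcast-multiply each channel of x by its gate
        return x * y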
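For reference, here is what the SE block computes, written out for a single sample in plain Python (no PyTorch) so the shapes are explicit. It is a sketch mirroring the `forward` above, not code from the thread; the weight layout follows `nn.Linear`'s `(out_features, in_features)` convention, so `w1` is `(C // reduction) x C` and `w2` is `C x (C // reduction)`. The last step multiplies each H x W channel by a per-channel gate of shape `(C, 1, 1)`, which is plausibly where the `reshape` and `broadcast_mul` in the failing kernel name come from (an inference from the kernel name, not checked against the exported ONNX graph).

```python
import math


def se_forward(x, w1, b1, w2, b2):
    """Squeeze-and-excitation for one sample.

    x:  list of C channels, each an H x W list of lists.
    w1: (C // r) x C weights, b1: C // r biases   (first nn.Linear)
    w2: C x (C // r) weights, b2: C biases        (second nn.Linear)
    """
    # Squeeze: global average pool over H and W -> one value per channel.
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # Excitation, part 1: Linear followed by ReLU.
    hidden = [max(0.0, sum(w * v for w, v in zip(row, z)) + b)
              for row, b in zip(w1, b1)]
    # Excitation, part 2: Linear followed by sigmoid -> per-channel gate in (0, 1).
    gate = [1.0 / (1.0 + math.exp(-(sum(w * v for w, v in zip(row, hidden)) + b)))
            for row, b in zip(w2, b2)]
    # Scale: the (C, 1, 1) gate is broadcast over each H x W channel of x.
    return [[[v * g for v in row] for row in ch] for ch, g in zip(x, gate)]
```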

Ok, I can reproduce your error. I think this is an interesting case.

Let me dig into it a bit.

I figured out the cause of your error. If you change this line to

if tag.is_injective(OP.tag):

it should work.
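A toy model of why widening that check can matter. Assumption: the line in question was a broadcast-only test (something like `tag.is_broadcast(OP.tag)`) inside the CUDA conv2d schedule's operator traversal, deciding which producer ops get inlined into the conv2d kernel. `reshape` is injective but not a broadcast, so a broadcast-only test would skip it and leave it without any GPU thread binding, which would trigger the "Did you forget to bind?" check. The tag constants and predicates below are simplified stand-ins, not TOPI's actual `topi.tag` module.

```python
# Simplified stand-ins for TOPI-style op tags (hypothetical, for illustration only).
ELEMWISE = "elemwise"
BROADCAST = "broadcast"
INJECTIVE = "injective"


def is_broadcast(tag):
    # Narrow check: element-wise and broadcast ops only.
    return tag in (ELEMWISE, BROADCAST)


def is_injective(tag):
    # Wider check: also accepts general injective ops such as reshape.
    return tag in (ELEMWISE, BROADCAST, INJECTIVE)


def left_unscheduled(producers, predicate):
    """Names of producer ops that a traversal guarded by `predicate` would not inline."""
    return [name for name, tag in producers if not predicate(tag)]


# Producers feeding the conv2d in the failing kernel; tags assumed for illustration.
PRODUCERS = [("reshape", INJECTIVE), ("broadcast_mul", BROADCAST)]
```

With these assumptions, `left_unscheduled(PRODUCERS, is_broadcast)` returns `["reshape"]`, while `left_unscheduled(PRODUCERS, is_injective)` returns `[]`.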

Thanks, it works well now. The problem is solved, thanks again.

Hi @masahi

I am facing a similar error while compiling an ONNX model.
Note: the error does not occur when compiling the model with NNVM at opt_level=0;
I only get it with opt_level=1, 2, or 3.
Code:
with nnvm.compiler.build_config(opt_level=1):
    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params=params)
Error:
Traceback (most recent call last):
  File "sface_nchw_trail1.py", line 64, in <module>
    graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params=params)
  File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/compiler/build_module.py", line 306, in build
    graph = graph.apply("GraphCompile")
  File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMCall CFunc Error:
Traceback (most recent call last):
  File "/home/ubuntu/tvm_opencl/tvm/python/tvm/_ffi/_ctypes/function.py", line 55, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/ubuntu/tvm_opencl/tvm/nnvm/python/nnvm/compiler/build_module.py", line 124, in _build
    return tvm.build(funcs, target=target, target_host=target_host)
  File "/home/ubuntu/tvm_opencl/tvm/python/tvm/build_module.py", line 586, in build
    fhost, mdev = _build_for_device(flist, tar, target_host)
  File "/home/ubuntu/tvm_opencl/tvm/python/tvm/build_module.py", line 415, in _build_for_device
    "Did you forget to bind?" % func.name)
ValueError: Direct host side access to device memory is detected in fuse_matmul_relu. Did you forget to bind?
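One way to read the opt_level observation: operator fusion is an optimization pass, and a kernel named `fuse_matmul_relu` only exists once matmul and relu have been fused. A minimal model of the opt_level gating, assuming the pass-level table looked roughly like the one in `nnvm/python/nnvm/compiler/build_module.py` at the time (the exact values may differ in your checkout); `enabled_passes` and `fusion_enabled` are hypothetical helpers, not NNVM functions.

```python
# Minimum opt_level at which each NNVM pass runs (assumed values, see lead-in).
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "OpFusion": 1,
    "PrecomputePrune": 2,
    "FoldScaleAxis": 3,
    "AlterOpLayout": 3,
}


def enabled_passes(opt_level):
    """Passes that would run under build_config(opt_level=opt_level)."""
    return [name for name, level in OPT_PASS_LEVEL.items() if opt_level >= level]


def fusion_enabled(opt_level):
    # matmul + relu are only merged into one kernel (e.g. fuse_matmul_relu) when OpFusion runs.
    return "OpFusion" in enabled_passes(opt_level)
```

Under these assumptions, only opt_level=0 skips fusion, which matches the error appearing at levels 1, 2, and 3 but not at 0.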

I don’t think we have a GPU schedule for matmul, so it falls back to the CPU matmul schedule, which produces the error you are seeing.
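A toy model of the fallback described above, using a hypothetical registry and hypothetical schedule names (this is not TVM's real dispatch mechanism): when no GPU schedule is registered for an op, the generic CPU-style schedule is picked, its loops are never bound to `blockIdx`/`threadIdx`, and a GPU build rejects the kernel. In this model OpenCL hits the same wall as CUDA, since both are device targets without a matmul entry.

```python
# Toy model with hypothetical names; not TVM's real schedule dispatch.
REGISTERED = {
    ("conv2d", "cuda"): "cuda_conv2d",        # GPU schedule: binds loops to blockIdx/threadIdx
    ("conv2d", "generic"): "generic_conv2d",  # CPU-style schedule: plain loops
    ("matmul", "generic"): "generic_matmul",  # no GPU schedule registered for matmul
}
GPU_TARGETS = {"cuda", "opencl"}


def pick_schedule(op, target):
    """Prefer a target-specific schedule, else fall back to the generic one."""
    return REGISTERED.get((op, target)) or REGISTERED[(op, "generic")]


def binds_threads(schedule_name):
    # In this model only target-specific schedules bind GPU thread axes.
    return not schedule_name.startswith("generic_")


def check_build(func_name, op, target):
    """Return the chosen schedule, or raise the error a GPU build would report."""
    schedule = pick_schedule(op, target)
    if target in GPU_TARGETS and not binds_threads(schedule):
        raise ValueError("Direct host side access to device memory is detected in %s. "
                         "Did you forget to bind?" % func_name)
    return schedule
```

In this model, writing a GPU schedule for matmul amounts to adding a `("matmul", "cuda")` entry whose schedule binds its loops to thread axes.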

Thanks @masahi
I tried running on OpenCL as well.

The error does not occur when compiling the model with NNVM at opt_level=0;
I only get it with opt_level=1, 2, or 3.

Hi @masahi

Is there any plan to write a GPU schedule for matmul?
Is there any way I can work on that? If so, could you please guide me?