[golang] TVM Go runtime error when running a conv model on CUDA

The model library files were saved after auto-tuning for CUDA, and they load and run normally in Python. However, when I use the Go API to load the model, the following error occurs:

TVM Version   : v0.5.dev
DLPACK Version: v16
Global Functions:[tvm.graph_runtime.create module._GetSystemLib module._Enabled device_api.cpu module.loadfile_so module._GetTypeKey tvm.graph_runtime.remote_create module._GetSource module._LoadFromFile runtime.config_threadpool _GetDeviceAttr __tvm_set_device module._ImportsSize module._GetImport module._SaveToFile]
[09:22:23] /usr/tvm/golang/..//src/runtime/module_util.cc:35: Check failed: f != nullptr Loader of cuda(module.loadbinary_cuda) is not presented.
Stack trace returned 10 entries:
[bt] (0) ./test(dmlc::StackTrace[abi:cxx11](unsigned long)+0x9d) [0x55b3ea]
[bt] (1) ./test(dmlc::LogMessageFatal::~LogMessageFatal()+0x2f) [0x55b713]
[bt] (2) ./test(tvm::runtime::ImportModuleBlob(char const*, std::vector<tvm::runtime::Module, std::allocator<tvm::runtime::Module> >*)+0x32b) [0x548974]
[bt] (3) ./test(tvm::runtime::DSOModuleNode::Init(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)+0x8d) [0x56759d]
[bt] (4) ./test() [0x54e4ab]
[bt] (5) ./test() [0x555214]
[bt] (6) ./test(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x5a) [0x56bc50]
[bt] (7) ./test(tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) const+0xd0) [0x56e378]
[bt] (8) ./test(tvm::runtime::Module::LoadFromFile(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)+0x267) [0x54958d]
[bt] (9) ./test(TVMModLoadFromFile+0x95) [0x546a46]
Please copy tvm compiled modules here and update the sample.go accordingly.
You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut

Is CUDA supported by the Go runtime? What should I do?

Here is how I export the library from the TensorFlow model:

import tvm
from tvm import relay
import numpy as np
from tvm import autotvm
from tvm.contrib import graph_runtime, util
import tensorflow as tf

# Compile a frozen TensorFlow Inception-v3 graph for CUDA with the best
# auto-tuned schedules, export the deployable artifacts (.so / .json /
# .params), then reload them, classify one image and time inference.

pb_file = '/data/tvm_experiment/cv_models/frozen_inception_v3.pb'
log_file = './inception_v3_1000_autotuning.log'

# Output artifact paths, defined once so the export and reload steps agree.
path_lib = "./img_classify_deploy.so"
path_json = "./img_classify_deploy.json"
path_params = "./img_classify_deploy.params"

with tf.gfile.FastGFile(pb_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# Target layout for the Relay conversion; the graph's own input tensor is
# fed as NHWC, see shape_dict.
layout = 'NCHW'
dtype = 'float32'
shape_dict = {'input': (1, 299, 299, 3)}
target = tvm.target.create('cuda -model=p40')
ctx = tvm.context(str(target), 0)
target_host = 'llvm'
net, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict)


# Pick the best schedule recorded in the tuning log for each tuned op.
with autotvm.apply_history_best(log_file):
    print("Compile...")
    with relay.build_config(opt_level=3):
        deploy_graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)

lib.export_library(path_lib)
with open(path_json, "w") as fo:
    fo.write(deploy_graph)
with open(path_params, "wb") as fo:
    fo.write(relay.param_dict.save_param_dict(params))

print('check if the library can be loaded......')
from PIL import Image
img_name = './elephant-299.jpg'
# Close the source file once the resized copy has been made.
with Image.open(img_name) as img:
    image = img.resize((299, 299))

x = np.array(image)
x = (x / 255.) - 0.5          # scale pixels to [-0.5, 0.5]
x = x.reshape((1,) + x.shape)  # add the batch dimension

######################################################################
# We can load the module back.
loaded_lib = tvm.module.load(path_lib)
with open(path_json) as fi:
    loaded_json = fi.read()
with open(path_params, "rb") as fi:
    loaded_params = bytearray(fi.read())
module = graph_runtime.create(loaded_json, loaded_lib, ctx)
# load_params consumes the serialized blob directly; deserializing it again
# and calling set_input(**params) would only set the same weights twice.
module.load_params(loaded_params)
module.set_input('input', tvm.nd.array(x.astype(dtype)))
module.run()

# NOTE(review): must equal the model's number of output classes — confirm.
num_classes = 75
# get the first output
tvm_output = module.get_output(0, tvm.nd.empty((1, num_classes), 'float32'))
predictions = np.squeeze(tvm_output.asnumpy())

# Map class index -> label, one label per line of labels.txt.
with open('/data/tvm_experiment/cv_models/labels.txt', 'r') as f:
    node_lookup = dict(enumerate(line.strip() for line in f))

# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup[node_id]
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
      (np.mean(prof_res), np.std(prof_res)))

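To enable CUDA, refer to the changes below. The log shows that the Go runtime was built without CUDA support. The list of registered global functions contains no CUDA entries, so `module.loadbinary_cuda` is missing when the CUDA module embedded in the exported library is imported. The patch compiles the CUDA runtime sources into `tvm_runtime_pack.cc` and links against the CUDA libraries.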

diff --git a/golang/Makefile b/golang/Makefile
index 54019740..17a8acda 100644
--- a/golang/Makefile
+++ b/golang/Makefile
@@ -7,10 +7,10 @@ NATIVE_SRC = tvm_runtime_pack.cc
 
 GOPATH=$(CURDIR)/gopath
 GOPATHDIR=${GOPATH}/src/${TARGET}/
-CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/"
+CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/ -I/usr/local/cuda/include"
 CGO_CXXFLAGS="-std=c++11"
 CGO_CFLAGS="-I${TVM_BASE}"
-CGO_LDFLAGS="-ldl -lm"
+CGO_LDFLAGS="-ldl -lm -lcuda -lcublas -lcudnn -lcudart -lnvrtc -L/usr/local/cuda-9.0/lib64/"
 
 all:
        @mkdir gopath 2>/dev/null || true
diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc
index 718a79eb..060dabd9 100644
--- a/golang/src/tvm_runtime_pack.cc
+++ b/golang/src/tvm_runtime_pack.cc
@@ -41,8 +41,8 @@
 // #include "../../src/runtime/metal/metal_module.mm"
 
 // Uncomment the following lines to enable CUDA
-// #include "../../src/runtime/cuda/cuda_device_api.cc"
-// #include "../../src/runtime/cuda/cuda_module.cc"
+#include "src/runtime/cuda/cuda_device_api.cc"
+#include "src/runtime/cuda/cuda_module.cc"
 
 // Uncomment the following lines to enable OpenCL
 // #include "../../src/runtime/opencl/opencl_device_api.cc"

I will try to push a simplified version of this change to support accelerators in the Go runtime.