[VTA] Running TensorFlow code in VTA simulation gives `c_runtime_api.cc` error

Following this tutorial, I can’t seem to compile a very simple TensorFlow graph and run it in the VTA simulation. Sometimes it runs just fine, but most of the time it fails with the following c_runtime_api.cc error.

  • Check failed: allow_missing Device API ext_dev is not enabled.

Did I miss any important steps when importing, compiling, or running the graph module?

python3 tf_add.py
... Tensorflow messages ...
Traceback (most recent call last):
  File "tf_add.py", line 130, in <module>
    m = graph_runtime.create(graph, lib, ctx)
  File "~/tvm/python/tvm/contrib/graph_runtime.py", line 43, in create
    return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
  File "~/tvm/python/tvm/_ffi/_ctypes/function.py", line 185, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "~/tvm/python/tvm/_ffi/base.py", line 72, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
tvm._ffi.base.TVMError: [15:03:17] ~/tvm/src/runtime/c_runtime_api.cc:69: Check failed: allow_missing Device API ext_dev is not enabled.

Stack trace returned 10 entries:
[bt] (0) ~/tvm/build/libtvm.so(+0xa045bc) [0x7f14579b15bc]
[bt] (1) ~/tvm/build/libtvm.so(+0xa048e5) [0x7f14579b18e5]
[bt] (2) ~/tvm/build/libtvm.so(+0x133a784) [0x7f14582e7784]
[bt] (3) ~/tvm/build/libtvm.so(+0x133a4d3) [0x7f14582e74d3]
[bt] (4) ~/tvm/build/libtvm.so(+0x133a2a1) [0x7f14582e72a1]
[bt] (5) ~/tvm/build/libtvm.so(tvm::runtime::DeviceAPI::Get(DLContext, bool)+0x21) [0x7f14582e471b]
[bt] (6) ~/tvm/build/libtvm.so(tvm::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DLDataType, DLContext)+0xad) [0x7f14582d95ab]
[bt] (7) ~/tvm/build/libtvm.so(+0x135fd4b) [0x7f145830cd4b]
[bt] (8) ~/tvm/build/libtvm.so(+0x135e2d6) [0x7f145830b2d6]
[bt] (9) ~/tvm/build/libtvm.so(+0x1361b2e) [0x7f145830eb2e]


Traceback (most recent call last):
  File "_ctypes/callbacks.c", line 234, in 'calling callback function'
  File "~/tvm/python/tvm/_ffi/_ctypes/function.py", line 28, in _ctypes_free_resource
TypeError: 'NoneType' object is not callable
Traceback (most recent call last):
  File "_ctypes/callbacks.c", line 234, in 'calling callback function'
  File "~/tvm/python/tvm/_ffi/_ctypes/function.py", line 28, in _ctypes_free_resource
TypeError: 'NoneType' object is not callable
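
As far as I can tell from the stack trace, the failure comes from NDArray::Empty asking DeviceAPI::Get for the ext_dev device. The following minimal snippet (my own reduction, assuming a local build with the VTA simulator) seems to exercise the same check in isolation:

import tvm
import vta  # loading the vta package should register the ext_dev runtime (my assumption)

# tvm.nd.empty allocates through DeviceAPI::Get for the given context,
# which is the exact call that fails in the trace above.
a = tvm.nd.empty((1, 16), "float32", tvm.ext_dev(0))
print(a.shape)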
  • For completeness and to reproduce the error, I have included the code below. It basically adds the left half of the image to the right half. The image source can be set in the “User settings” section of the code. You can also set TVM = False to run the code on the CPU, just to check that it works fine there.
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

import nnvm
import tvm
import vta
from tvm import rpc
from tvm.contrib import graph_runtime, util

# User settings ------------------------------------------------
TVM = True
data_dir = "temp/"
image_name = 'image.jpg'
# --------------------------------------------------------------

# Input image
image = Image.open(os.path.join(data_dir, image_name))
plt.imshow(image)
plt.show()

image = np.array(image)
height, width, channel = image.shape
cutoff = width // 2

img_1 = np.copy(image)
img_1[:, :cutoff, :] = 0
img_1 = img_1[np.newaxis, :, :, :]

img_2 = np.copy(image)
img_2[:, cutoff:, :] = 0
img_2 = img_2[np.newaxis, :, :, :]

x = tf.placeholder(dtype=tf.float32, shape=(1, height, width, 3), name='input')
b = tf.constant(img_2, dtype=tf.float32, shape=(1, height, width, 3), name='b')
y = tf.math.add(x, b, name='output')

# Graph
graph_name = "graph.pb"
graph_path = os.path.join(data_dir, graph_name)

with tf.Session() as sess:
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                    sess.graph.as_graph_def(add_shapes=True),
                                                                    ['output'])
    with tf.gfile.GFile(graph_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    net_out = sess.run(y, feed_dict={x: img_1})
    
    plt.imshow(net_out[0, :, :, :].astype(np.uint8))
    plt.show()
    Image.fromarray(net_out[0, :, :, :].astype(np.uint8)).save('add_expected.png')

if TVM:
    # VTA settings ----------------------------------------------------
    env = vta.get_env()
    host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
    port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))

    if env.TARGET == "pynq":
        assert tvm.module.enabled("rpc")
        remote = rpc.connect(host, port)
        vta.reconfig_runtime(remote) 
        vta.program_fpga(remote)

    elif env.TARGET == "sim":
        remote = rpc.LocalSession()

    device = "vta"
    ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
    target = tvm.target.create("llvm -device={}".format(device))

    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    # End of VTA settings ---------------------------------------------

    # Import graph ----------------------------------------------------
    import_file = graph_path
    with tf.gfile.GFile(import_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')

        sym, params = nnvm.frontend.from_tensorflow(graph_def)

        shape_dict = {'input': (1, height, width, 3)}
        dtype_dict = {'input': 'float32'}

        with vta.build_config():
            graph, lib, params = nnvm.compiler.build(graph=sym,
                                                     shape=shape_dict,
                                                     dtype=dtype_dict,
                                                     target=target,
                                                     params=params,
                                                     target_host=target_host)
        
        assert tvm.module.enabled("rpc")
        temp = util.tempdir()
        lib.save(temp.relpath("graphlib.o"))

        remote.upload(temp.relpath("graphlib.o"))
        lib = remote.load_module("graphlib.o")

        # Create a runtime graph
        m = graph_runtime.create(graph, lib, ctx)
        m.set_input(**params)
        m.set_input('input', tvm.nd.array(img_1.astype(np.float32)))
        m.run()
        
        tvm_output = m.get_output(0)
        net_out = tvm_output.asnumpy()[0]

        plt.imshow(net_out.astype(np.uint8))
        plt.show()
        Image.fromarray(net_out.astype(np.uint8)).save('add_{}.png'.format(env.TARGET))

else:
    # This is for a CPU test
    import_file = graph_path
    with tf.gfile.GFile(import_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        with tf.Session() as sess:
            inp = sess.graph.get_tensor_by_name('input:0')
            out = sess.graph.get_tensor_by_name('output:0')

            net_out = sess.run(out, feed_dict={inp: img_1})

            plt.imshow(net_out[0, :, :, :].astype(np.uint8))
            plt.show()
            Image.fromarray(net_out[0, :, :, :].astype(np.uint8)).save('add_cpu.png')

Thanks for bringing this up. VTA is a specialized accelerator that has certain restrictions (it only works with 8-bit fixed point) and a restricted ALU. As a result, the models we can feed to it are pretty much restricted to quantized models.

There is an ongoing effort to bring an automatic model quantizer that maps a Relay program to VTA-compatible quantized models. Once that is checked in, we will have a path from TF models.
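
For reference, the data-type constraints are visible on the VTA environment object itself; a quick inspection sketch (the attribute values depend on your vta_config.json; int8/int32 are the defaults as far as I know):

import vta

env = vta.get_env()
# VTA computes on narrow fixed-point types and accumulates in a wider one:
print(env.inp_dtype)  # input tensor type, e.g. int8
print(env.wgt_dtype)  # weight tensor type, e.g. int8
print(env.acc_dtype)  # accumulator type, e.g. int32
print(env.out_dtype)  # output type written back, e.g. int8
# Tensors must also be packed to the hardware tile shape:
print(env.BATCH, env.BLOCK_IN, env.BLOCK_OUT)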


So what would be a ‘working’ example that can be run with the existing VTA simulator?

Stated otherwise, how does the provided example need to be modified to run successfully on the existing TVM/VTA software (v0.5.0)?

The original inference example https://docs.tvm.ai/vta/tutorials/resnet.html#sphx-glr-vta-tutorials-resnet-py uses a quantized version of ResNet and should run out of the box.
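
(A quick sanity check before running it, assuming a stock source checkout: the tutorial picks its execution path from env.TARGET, which is read from the vta_config.json in the repo, and it should be "sim" for the local simulator.)

import vta

env = vta.get_env()
# "sim" -> local simulator; "pynq" -> RPC to a board.
print(env.TARGET)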

@tqchen @Ravenwater I just modified the TensorFlow graph to use uint8 as the data type, but it still doesn’t run; the same error messages appear.
In short, the steps where I import the graph and compile it with nnvm look like this:

with tf.gfile.GFile('graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    sym, params = nnvm.frontend.from_tensorflow(graph_def)  # <----------

    shape_dict = {'input': (1, height, width, 3)}
    dtype_dict = {'input': 'uint8'}
    with vta.build_config():
        graph, lib, params = nnvm.compiler.build(graph=sym,  # <---------
                                                 shape=shape_dict,
                                                 dtype=dtype_dict,
                                                 target=target,
                                                 params=params,
                                                 target_host=target_host)
    ...
    m = graph_runtime.create(graph, lib, ctx)  # <------------------------

I believe the problem is one of the following:

  1. I didn’t call the APIs properly.
  2. The graph fed to nnvm.frontend.from_tensorflow() is not compatible. But that seems questionable to me because the graph is very simple; I have attached a picture of it below.

(attached image “add_2”: the TensorFlow graph visualization)

You are right that directly using uint8 may not work either. The specifics of the accelerator require a special combination of ops, so I would suggest waiting a bit; we will share more details after the automatic floating-point import tool gets into mainline.


@tqchen what is this ‘automatic floating point importing tool’? Is this another point solution that hard-codes arithmetic types, or does it properly abstract the arithmetic number systems underlying the computation? The graph structure should be independent of the arithmetic system so that we can easily map to specific number systems such as fp16, bfloat16, posit, INT8, Elias gamma, etc.

There is going to be a ton of innovation in the arithmetic domain, as that is where all the performance-per-watt improvements will be found. The control of how to execute the graph is common to all, but the number system is going to be the differentiator.
