Unable to get started with Metal on macOS 10.13 with a Radeon Pro 560 (4096 MB)

Can anyone give me some pointers to get started with Metal/OpenCL in C++ with TensorFlow?

I wrote the Python script below:

# tvm, relay
import tvm
from tvm import te
from tvm import relay

# os and numpy
import numpy as np
import os.path

# Tensorflow imports
import tensorflow.compat.v1 as tf
#import tensorflow as tf
tf_compat_v1 = tf
tf.disable_v2_behavior()

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing

# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'

# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)

model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)

# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
map_proto_url = os.path.join(repo_base, map_proto)

# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
label_map_url = os.path.join(repo_base, label_map)

# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'metal'
target_host = 'llvm'
layout = "NCHW"
ctx = tvm.metal(0)
from tvm.contrib.download import download_testdata

img_path = download_testdata(image_url, img_name, module='data')
model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
map_proto_path = download_testdata(map_proto_url, map_proto, module='data')
label_path = download_testdata(label_map_url, label_map, module='data')

with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')

from PIL import Image
image = Image.open(img_path).resize((299, 299))

x = np.array(image)

shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)

print("Tensorflow protobuf imported to relay frontend.")
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod,
                                     target=target,
                                     target_host=target_host,
                                     params=params)



from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))

predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)

# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))


def create_graph():
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf_compat_v1.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)

def run_inference_on_image(image):
    """Runs inference on an image.

    Parameters
    ----------
    image: String
        Image file name.

    Returns
    -------
        Nothing
    """
    if not tf_compat_v1.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf_compat_v1.gfile.GFile(image, 'rb').read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf_compat_v1.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})

        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                            uid_lookup_path=label_path)

        # Print top 5 predictions from tensorflow.
        top_k = predictions.argsort()[-5:][::-1]
        print ("===== TENSORFLOW RESULTS =======")
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

run_inference_on_image(img_path)

Does this work for anyone on similar hardware?

I get this:

(venv) kaosnew:tvm_test sam$ python metal_tf_demo.py 
WARNING:tensorflow:From /Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tensorflow_core/python/compat/v2_compat.py:88: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
File /Users/sam/.tvm_test_data/data/elephant-299.jpg exists, skip.
File /Users/sam/.tvm_test_data/tf/InceptionV1/classify_image_graph_def-with_shapes.pb exists, skip.
File /Users/sam/.tvm_test_data/data/imagenet_2012_challenge_label_map_proto.pbtxt exists, skip.
File /Users/sam/.tvm_test_data/data/imagenet_synset_to_human_label_map.txt exists, skip.
2020-04-04 14:23:50.244505: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
2020-04-04 14:23:50.426163: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-04-04 14:23:50.438430: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fb8f5115790 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-04-04 14:23:50.438450: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
WARNING:tensorflow:From /Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/testing/tf.py:95: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
WARNING:tensorflow:From /Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/frontend/tensorflow.py:2552: UserWarning: Ignore the passed shape. Shape in graphdef will be used for operator DecodeJpeg/contents.
  "will be used for operator %s." % node.name)
/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/frontend/tensorflow.py:621: UserWarning: DecodeJpeg: It's a pass through, please handle preprocessing before input
  warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
WARNING:root:Attribute Tdim is ignored in relay.sym.expand_dims
WARNING:root:Attribute T is ignored in relay.sym.expand_dims
    ... (more warnings of the same kind elided)
Tensorflow protobuf imported to relay frontend.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw_winograd.cuda', ('TENSOR', (1, 32, 149, 149), 'float32'), ('TENSOR', (32, 32, 3, 3), 'float32'), (1, 1), (0, 0, 0, 0), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw_winograd.cuda', ('TENSOR', (1, 32, 147, 147), 'float32'), ('TENSOR', (64, 32, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw_winograd.cuda', ('TENSOR', (1, 80, 73, 73), 'float32'), ('TENSOR', (192, 80, 3, 3), 'float32'), (1, 1), (0, 0, 0, 0), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw_winograd.cuda', ('TENSOR', (1, 48, 35, 35), 'float32'), ('TENSOR', (64, 48, 5, 5), 'float32'), (1, 1), (2, 2, 2, 2), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 64, 35, 35), 'float32'), ('TENSOR', (96, 64, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 96, 35, 35), 'float32'), ('TENSOR', (96, 96, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=metal, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 448, 8, 8), 'float32'), ('TENSOR', (384, 448, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):

  File "metal_tf_demo.py", line 80, in <module>
    params=params)

  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 251, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)

  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 120, in build
    self._build(mod, target, target_host)

  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000110a9a8b9 tvm::NodeFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) const + 297
  [bt] (7) 8   libtvm.dylib                        0x0000000110a9c268 tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) + 24
  [bt] (6) 7   libtvm.dylib                        0x0000000110a989b2 tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*) + 722
  [bt] (5) 6   libtvm.dylib                        0x0000000110a9797c tvm::relay::ScheduleGetter::VisitExpr(tvm::RelayExpr const&) + 252
  [bt] (4) 5   libtvm.dylib                        0x0000000110a9a602 tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) + 226
  [bt] (3) 4   libtvm.dylib                        0x0000000110a9a8b9 tvm::NodeFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) const + 297
  [bt] (2) 3   libtvm.dylib                        0x0000000110a9c268 tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) + 24
  [bt] (1) 2   libtvm.dylib                        0x0000000110a98ee3 tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*) + 2051
  [bt] (0) 1   libtvm.dylib                        0x0000000110beee25 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 250, in lower_call
    op, call.attrs, inputs, ret_type, target)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 183, in select_implementation
    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 124, in get_valid_implementations
    strategy = fstrategy(attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 45, in __call__
    return _ffi_api.GenericFuncCallFunc(self, *args)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (5) 6   ???                                 0x00007ffeeb3b9a50 0x0 + 140732844972624
  [bt] (4) 5   _ctypes.cpython-37m-darwin.so       0x000000010531336f ffi_call_unix64 + 79
  [bt] (3) 4   libtvm.dylib                        0x0000000110bed266 TVMFuncCall + 70
  [bt] (2) 3   libtvm.dylib                        0x00000001105c70b5 std::__1::__function::__func<tvm::$_5, std::__1::allocator<tvm::$_5>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 181
  [bt] (1) 2   libtvm.dylib                        0x00000001105c4de7 tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const + 743
  [bt] (0) 1   libtvm.dylib                        0x0000000110beee25 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/strategy/cuda.py", line 313, in dense_strategy_cuda
    if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 218, in compute_version
    self.device_type, self.device_id, 4)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 180, in _GetDeviceAttr
    device_type, device_id, attr_id)
  File "/Users/sam/dev/github/tvm/build/numberwang/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (6) 7   ???                                 0x00007ffeeb3b8360 0x0 + 140732844966752
  [bt] (5) 6   _ctypes.cpython-37m-darwin.so       0x000000010531336f ffi_call_unix64 + 79
  [bt] (4) 5   libtvm.dylib                        0x0000000110bed266 TVMFuncCall + 70
  [bt] (3) 4   libtvm.dylib                        0x0000000110bef340 std::__1::__function::__func<$_4, std::__1::allocator<$_4>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 400
  [bt] (2) 3   libtvm.dylib                        0x0000000110bee4a4 tvm::runtime::DeviceAPIManager::GetAPI(int, bool) + 532
  [bt] (1) 2   libtvm.dylib                        0x0000000110bee735 tvm::runtime::DeviceAPIManager::GetAPI(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, bool) + 421
  [bt] (0) 1   libtvm.dylib                        0x00000001101a8909 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/sam/dev/github/tvm/src/runtime/c_runtime_api.cc", line 133
TVMError: Check failed: allow_missing: Device API gpu is not enabled.

Can anyone help me debug my way through this?

On the other hand:

It seems to run just fine after changing .cc to .mm and uncommenting the Metal defines.

So it seems that Metal is at least partly working.
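For reference, the Metal defines in question are presumably the build switches in TVM's config.cmake; a minimal sketch, assuming the flag names from the 0.7-dev source tree:

# config.cmake (copied from cmake/config.cmake into the build directory)
set(USE_METAL ON)   # compile the Metal runtime (the .mm sources)
set(USE_LLVM ON)    # needed for the 'llvm' target_host
set(USE_CUDA OFF)   # no NVIDIA hardware on this machine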

To run things on Metal, you want to change tvm.gpu(0) to tvm.metal(0), because gpu refers to the CUDA GPU for historical reasons.

If you read the pasted code, you will see that the tvm.gpu(0) context is commented out and tvm.metal(0) is used.

It seems to be due to the fact that no Metal-specific schedule was registered: the code falls back into the gpu path (which was registered via cuda), and there is a TensorCore-related check in that GPU path.
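For context, in the 0.7-dev strategy code the CUDA strategy functions are registered under both the "cuda" key and the generic "gpu" key, so a target like Metal that carries the "gpu" key but has no registration of its own dispatches into them. An abridged excerpt from python/tvm/relay/op/strategy/cuda.py:

@conv2d_strategy.register(["cuda", "gpu"])
def conv2d_strategy_cuda(attrs, inputs, out_type, target):
    # a 'metal' target matches the generic "gpu" key, so dispatch lands
    # here, and the body may end up calling tvm.gpu(0), the CUDA device
    ...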

cc @Hzfengsy @Shawn_Inspur, can you look a bit into how we can skip the nvcc.have_tensorcore check when CUDA is not available?

Here is a duplicate of the same report from another post:

Running the following command:

python ~/dev/tvm_test/from_tensorflow_metal.py

results in the following backtrace:

Traceback (most recent call last):

  File "/Users/sam/dev/tvm_test/metal_tf_demo.py", line 75, in <module>
    params=params)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 251, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/build_module.py", line 120, in build
    self._build(mod, target, target_host)

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x00000001159c6e5e tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*) + 14
  [bt] (7) 8   libtvm.dylib                        0x00000001159c8cf9 tvm::RelayExpr tvm::relay::MixedModeMutator::Rewrite<tvm::relay::CallNode>(tvm::relay::CallNode const*) + 57
  [bt] (6) 7   libtvm.dylib                        0x00000001159c7c82 tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&) + 2082
  [bt] (5) 6   libtvm.dylib                        0x00000001159c91fe tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void>&, tvm::runtime::ObjectRef>(tvm::relay::Call const&&&, tvm::Array<tvm::RelayExpr, void>&&&, tvm::runtime::ObjectRef&&) const + 254
  [bt] (4) 5   libtvm.dylib                        0x000000011596c9fd std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 109
  [bt] (3) 4   libtvm.dylib                        0x000000011596ca92 void tvm::runtime::detail::unpack_call_dispatcher<tvm::RelayExpr, 0, 3, tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(tvm::RelayExpr (* const&)(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 82
  [bt] (2) 3   libtvm.dylib                        0x0000000115963b14 tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&) + 3364
  [bt] (1) 2   libtvm.dylib                        0x00000001159675a4 tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::__1::vector<tvm::RelayExpr, std::__1::allocator<tvm::RelayExpr> > const&) + 1284
  [bt] (0) 1   libtvm.dylib                        0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/nn/_nn.py", line 97, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-35>", line 2, in conv2d_alter_layout
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 267, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 183, in select_implementation
    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/backend/compile_engine.py", line 124, in get_valid_implementations
    strategy = fstrategy(attrs, inputs, out_type, target)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/target/generic_func.py", line 45, in __call__
    return _ffi_api.GenericFuncCallFunc(self, *args)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (5) 6   ???                                 0x00007ffee6c70d30 0x0 + 140732770225456
  [bt] (4) 5   _ctypes.cpython-37m-darwin.so       0x0000000109a5b36f ffi_call_unix64 + 79
  [bt] (3) 4   libtvm.dylib                        0x0000000115bfcbd6 TVMFuncCall + 70
  [bt] (2) 3   libtvm.dylib                        0x00000001155e2955 std::__1::__function::__func<tvm::$_5, std::__1::allocator<tvm::$_5>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 181
  [bt] (1) 2   libtvm.dylib                        0x00000001155e0687 tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const + 743
  [bt] (0) 1   libtvm.dylib                        0x0000000115bfe795 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 78, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/relay/op/strategy/cuda.py", line 125, in conv2d_strategy_cuda
    if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 218, in compute_version
    self.device_type, self.device_id, 4)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/runtime_ctypes.py", line 180, in _GetDeviceAttr
    device_type, device_id, attr_id)
  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/tvm-0.7.dev1-py3.7-macosx-10.13-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()
  [bt] (6) 7   ???                                 0x00007ffee6c6f640 0x0 + 140732770219584
  [bt] (5) 6   _ctypes.cpython-37m-darwin.so       0x0000000109a5b36f ffi_call_unix64 + 79
  [bt] (4) 5   libtvm.dylib                        0x0000000115bfcbd6 TVMFuncCall + 70
  [bt] (3) 4   libtvm.dylib                        0x0000000115bfecb0 std::__1::__function::__func<$_4, std::__1::allocator<$_4>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 400
  [bt] (2) 3   libtvm.dylib                        0x0000000115bfde14 tvm::runtime::DeviceAPIManager::GetAPI(int, bool) + 532
  [bt] (1) 2   libtvm.dylib                        0x0000000115bfe0a5 tvm::runtime::DeviceAPIManager::GetAPI(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, bool) + 421
  [bt] (0) 1   libtvm.dylib                        0x0000000115197829 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/sam/dev/github/tvm/src/runtime/c_runtime_api.cc", line 133
TVMError: Check failed: allow_missing: Device API gpu is not enabled.

What is concerning is that CUDA code is being called even though tvm.gpu(0) and tvm.target.cuda() are never mentioned in the script:

  File "/Users/sam/dev/github/tvm/build-runtime/python-tvm/venv/lib/python3.7/site-packages/topi-0.7.dev1-py3.7.egg/topi/cuda/conv2d_alter_op.py", line 39, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target)

See:

(venv) kaosnew:build sam$ grep -rn cuda ~/dev/tvm_test/metal_tf_demo.py
(venv) kaosnew:build sam$ grep -rn gpu ~/dev/tvm_test/metal_tf_demo.py

Neither search returns a match.
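The dispatch is driven by the target's keys rather than by anything in the script. A quick way to see this from the REPL; a sketch assuming the 0.7-dev API:

>>> import tvm
>>> tgt = tvm.target.create('metal')
>>> tgt.keys  # the generic 'gpu' key is what routes dispatch into strategy/cuda.py
['metal', 'gpu']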

The script is identical to the one posted above, except for the target settings:

# Target settings
target = 'metal'
target_host = 'llvm'
layout = None # "NCHW"
ctx = tvm.context(target, 1)

Is it that contrib.graph_runtime has only ever been tested with CUDA and is making assumptions?

Here are the changes relative to the original tutorial, found here: https://tvm.apache.org/docs/tutorials/frontend/from_tensorflow.html#sphx-glr-download-tutorials-frontend-from-tensorflow-py

(venv) kaosnew:build sam$ diff from_tensorflow_llvm.py from_tensorflow_metal.py
76c76
< target = 'llvm'
---
> target = 'metal'
79c79
< ctx = tvm.cpu(0)
---
> ctx = tvm.metal(0)

@haichen, under the strategy design, shall we create a separate strategy for the TensorCore-related code and register it only for cuda (not gpu, so it won't affect other GPU kinds like Metal and OpenCL)?

Hey @tqchen, sorry for being all cold and prickly.

That wasn't my intent; I am just trying to decipher what I am doing wrong and how TVM is supposed to work.

I had backed myself into a corner where I could see that I was doing the correct thing, yet the code was going down a curious path despite my intentions.

Thanks for getting on it so quickly; this will be really awesome once it is working. Sam

Yes, I agree. We should have a separate strategy for cuda conv2d, as it contains implementations that only apply to CUDA. But we should reuse the implementations that are generally applicable to any GPU target.

I can prepare a PR to fix this in a few days.
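A minimal sketch of the shape such a split could take (the names and structure here are hypothetical; the actual PR may differ):

# hypothetical: schedules that work on any GPU-like backend
@conv2d_strategy.register("gpu")
def conv2d_strategy_gpu(attrs, inputs, out_type, target):
    strategy = _op.OpStrategy()
    # add only the plain topi.cuda compute/schedule pairs that are
    # valid on Metal, OpenCL, Vulkan, etc.
    return strategy

# hypothetical: CUDA-only extras, registered for the "cuda" key only
@conv2d_strategy.register("cuda")
def conv2d_strategy_cuda(attrs, inputs, out_type, target):
    strategy = conv2d_strategy_gpu(attrs, inputs, out_type, target)
    if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
        pass  # add the TensorCore implementations here
    return strategy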


Yes, we are considering how to fix this issue.

It is better to have a separate strategy for TensorCore and only register it for CUDA, as you suggested.

Interestingly, here is my machine, which has no NVIDIA hardware:

(venv) kaosnew:build sam$ python
Python 3.7.3 (default, Jun 19 2019, 07:40:15) 
[Clang 10.0.0 (clang-1000.11.45.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tvm.contrib.nvcc as nvcc
>>> print(nvcc.find_cuda_path())
/usr/local/cuda
>>> print (nvcc.get_cuda_version(nvcc.find_cuda_path()))
9.0
>>> print(nvcc.find_libdevice_path(64))
/usr/local/cuda/nvvm/libdevice/libdevice.10.bc
>>> 
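Finding a CUDA toolkit on disk is not the same as the runtime having the CUDA device API compiled in; the "Device API gpu is not enabled" check failure refers to the latter. A quick sanity check, assuming the 0.7-dev TVMContext API:

>>> import tvm
>>> tvm.gpu(0).exist    # False: libtvm was built without USE_CUDA
False
>>> tvm.metal(0).exist  # True: the Metal runtime is compiled in
True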

There is a simple and direct way to fix this issue: check both target.target_name and nvcc.have_tensorcore, as below, instead of checking only nvcc.have_tensorcore:

if target.target_name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
    # TensorCore code path...

The code for both conv2d and dense should be modified.

With this change, short-circuit evaluation skips the nvcc.have_tensorcore check whenever the target is not cuda, so it is never reached when CUDA is unavailable.
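For completeness, the dense path hit in the first traceback (dense_strategy_cuda, cuda.py line 313) would take the same guard; a sketch:

if target.target_name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
    # the TensorCore dense schedule is only considered on a real CUDA device
    ...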

Hey @Shawn_Inspur, I will patch that into the installed .egg and see if it blends.

The elephant in the room is solved:

WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
WARNING:strategy:For x86 target, NCHW layout is recommended for conv2d.
African elephant, Loxodonta africana (score = 0.58335)
tusker (score = 0.33901)
Indian elephant, Elephas maximus (score = 0.02391)
banana (score = 0.00025)
vault (score = 0.00021)
===== TENSORFLOW RESULTS =======
African elephant, Loxodonta africana (score = 0.58394)
tusker (score = 0.33909)
Indian elephant, Elephas maximus (score = 0.03186)
banana (score = 0.00022)
desk (score = 0.00019)
(venv) kaosnew:topi sam$ git diff
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index db03c59..10141d6 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -87,7 +87,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         raise ValueError("dilation should be positive value")
 
     if groups == 1:
-        if layout == "NCHW":
+        if target.target_name == "cuda" and layout == "NCHW":
             assert kernel_layout == "OIHW"
             if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
                 assert data.dtype == kernel.dtype
@@ -108,13 +108,13 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
                     name="conv2d_nchw_winograd.cuda",
                     plevel=5)
-        elif layout == "HWCN":
+        elif target.target_name == "cuda" and layout == "HWCN":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                 name="conv2d_hwcn.cuda")
-        elif layout == "NHWC":
+        elif target.target_name == "cuda" and layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
@@ -122,7 +122,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 name="conv2d_nhwc.cuda")
             N, _, _, _ = get_const_tuple(data.shape)
             _, _, CI, CO = get_const_tuple(kernel.shape)
-            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+            if target.target_name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                 if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
                         (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
                         (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
@@ -131,7 +131,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
                         name="conv2d_nhwc_tensorcore.cuda",
                         plevel=20)
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+        elif target.target_name == "cuda" and layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),

Related: https://github.com/apache/incubator-tvm/issues/5370. The original problem should be fixed by now.