[TVM Tensorflow frontend] argwhere op does not support dynamic output shape?

Hi all, according to the TVM TensorFlow frontend code, the argwhere op is already supported (see the argwhere PR), but when I recently converted a TensorFlow model into TVM Relay IR, the TensorFlow frontend failed to do the conversion. The traceback is as follows:

Traceback (most recent call last):
  File "from_hfnet_to_tvm.py", line 118, in <module>
    outputs=output_nodes)
  File "/home/admin/tvm/python/tvm/relay/frontend/tensorflow.py", line 2387, in from_tensorflow
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
  File "/home/admin/tvm/python/tvm/relay/frontend/tensorflow.py", line 2050, in from_tensorflow
    out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
  File "/home/admin/tvm/python/tvm/relay/frontend/tensorflow.py", line 2050, in <listcomp>
    out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
  File "/home/admin/tvm/python/tvm/relay/frontend/common.py", line 466, in infer_shape
    out_shapes = get_const_tuple(out_type.checked_type.shape)
  File "/home/admin/tvm/topi/python/topi/util.py", line 164, in get_const_tuple
    return tuple(get_const_int(elem) for elem in in_tuple)
  File "/home/admin/tvm/topi/python/topi/util.py", line 164, in <genexpr>
    return tuple(get_const_int(elem) for elem in in_tuple)
  File "/home/admin/tvm/topi/python/topi/util.py", line 101, in get_const_int
    expr = tvm.ir_pass.Simplify(expr)
  File "/home/admin/tvm/python/tvm/_ffi/_ctypes/function.py", line 210, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (6) /home/admin/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f694847e8c1]
  [bt] (5) /home/admin/tvm/build/libtvm.so(+0x44ae9c) [0x7f6947c7ae9c]
  [bt] (4) /home/admin/tvm/build/libtvm.so(tvm::ir::Simplify(tvm::Expr, tvm::Map<tvm::Var, tvm::Range, void, void>)+0x21f) [0x7f6947d23faf]
  [bt] (3) /home/admin/tvm/build/libtvm.so(tvm::arith::Analyzer::Simplify(tvm::Expr const&)+0x1e8) [0x7f6947d87e58]
  [bt] (2) /home/admin/tvm/build/libtvm.so(tvm::arith::RewriteSimplifier::operator()(tvm::Expr const&)+0xa9) [0x7f6947d26779]
  [bt] (1) /home/admin/tvm/build/libtvm.so(tvm::IRFunctor<tvm::Expr (tvm::NodeRef const&, tvm::Expr const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, tvm::Expr const&, tvm::ir::IRMutator*) const+0x10a) [0x7f6947cdc2ba]
  [bt] (0) /home/admin/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f6947c7d012]
  File "/home/admin/tvm/include/tvm/node/ir_functor.h", line 91
TVMError: Check failed: type_index < func_.size() && func_[type_index] != nullptr: IRFunctor calls un-registered function on type Any

After spending some time digging into the code, I think the problem might be this: the output shape of the argwhere op is dynamic, while the _infer_shape function requires a constant tuple of integers, which causes the failure.
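
To illustrate the point, here is a minimal sketch (not taken from my model, and it assumes a TVM build where relay.argwhere, tvm.IRModule and relay.transform.InferType are available): the inferred output type of argwhere has an Any first dimension, so it can never be reduced to a constant tuple.

import tvm
from tvm import relay

# A tiny Relay function whose body is just argwhere.
cond = relay.var("cond", shape=(3, 4), dtype="float32")
func = relay.Function([cond], relay.argwhere(cond))

# Run type inference and inspect the inferred return type.
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)
# Prints something like Tensor[(?, 2), int32]: the number of non-zero
# entries is only known at runtime, so the first dimension is Any and
# get_const_tuple / get_const_int cannot simplify it to a constant.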

So my question is: does the TVM TensorFlow frontend support ops with dynamic output shapes, such as argwhere and TopKV2, which TVM TOPI already supports? If not, is there a workaround to import a TensorFlow model that contains such ops into TVM Relay IR?

Here is the script I used:


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with TVM.

For us to begin with, tensorflow python module is required to be installed.

Please refer to https://www.tensorflow.org/install
"""

# tvm, relay
import tvm
from tvm import relay

# os and numpy
import numpy as np
import os.path

# Tensorflow imports
import tensorflow as tf
tf.contrib.resampler

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing

# Base location for model related files.
repo_base = '.'



######################################################################
# Tutorials
# ---------
# Please refer docs/frontend/tensorflow.md for more details for various models
# from tensorflow.

model_name = 'transformed_model.pb'
model_path = os.path.join(repo_base, model_name)





# Target settings
# Build for cuda.
target = 'cuda'
target_host = 'llvm'
layout = "NCHW"
ctx = tvm.gpu(0)




######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.
output_nodes = ['local_descriptors','scores','keypoints','global_descriptor']
with tf.io.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    # with tf.Session() as sess:
    #     graph_def = tf_testing.AddShapesToGraphDef(sess, output_nodes)

######################################################################
# Decode image
# ------------
# .. note::
#
#   tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
#   JpegDecode is bypassed (just return source node).
#   Hence we supply decoded frame to TVM instead.
#

# from PIL import Image
# image = Image.open(img_path).resize((640, 480))

x = np.random.rand(1,480,640,1)

######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
#   sym: relay expr for given tensorflow protobuf.
#   params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'pred/strided_slice_3': x.shape}
dtype_dict = {'pred/strided_slice_3': 'float32'}

mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict,
                                             outputs=output_nodes)

print("Tensorflow protobuf imported to relay frontend.")
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
#
# Results:
#   graph: Final graph after compilation.
#   params: final params after compilation.
#   lib: target library which can be deployed on target with TVM runtime.

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod,
                                     target=target,
                                     target_host=target_host,
                                     params=params)

######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.

from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
#tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))

######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
# predictions = tvm_output.asnumpy()
# predictions = np.squeeze(predictions)

# # Creates node ID --> English string lookup.
# node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
#                                     uid_lookup_path=label_path)

# # Print top 5 predictions from TVM output.
# top_k = predictions.argsort()[-5:][::-1]
# for node_id in top_k:
#     human_string = node_lookup.id_to_string(node_id)
#     score = predictions[node_id]
#     print('%s (score = %.5f)' % (human_string, score))

######################################################################
# Inference on tensorflow
# -----------------------
# Run the corresponding model on tensorflow

def create_graph():
    """Creates a graph from saved GraphDef file and returns a saver."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        # Call the utility to import the graph definition into default graph.
        graph_def = tf_testing.ProcessGraphDefParam(graph_def)

def run_inference_on_image(image):
    """Runs inference on an image.

    Parameters
    ----------
    image: String
        Image file name.

    Returns
    -------
        Nothing
    """
    if not tf.gfile.Exists(image):
        tf.logging.fatal('File does not exist %s', image)
    image_data = tf.gfile.FastGFile(image, 'rb').read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})

        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                            uid_lookup_path=label_path)

        # Print top 5 predictions from tensorflow.
        top_k = predictions.argsort()[-5:][::-1]
        print ("===== TENSORFLOW RESULTS =======")
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

#run_inference_on_image(img_path)

I think that’s because of the function TransformShape in the file data_layout.cc. It is used to split a shape, e.g. from NCHW to NCHWc. But when you want to convert your output shape from NCHW to NHWC, for example, it will cause an error. So you need to add your own shape transform function.
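
As a rough illustration of the difference (plain Python for clarity, not the actual TVM C++ code, and the helper names are made up): NCHW to NCHWc splits the channel axis into blocks, while NCHW to NHWC only permutes axes, so a transform that only knows how to split axes cannot handle a pure permutation.

# Hypothetical helpers, only to show the two kinds of shape transforms.
def split_nchw_to_nchwc(shape, c_block=8):
    # Layout split: the C axis is broken into (C // c_block, c_block).
    n, c, h, w = shape
    assert c % c_block == 0
    return (n, c // c_block, h, w, c_block)

def permute_nchw_to_nhwc(shape):
    # Layout permutation: no axis is split, axes are only reordered.
    n, c, h, w = shape
    return (n, h, w, c)

print(split_nchw_to_nchwc((1, 32, 480, 640)))   # (1, 4, 480, 640, 8)
print(permute_nchw_to_nhwc((1, 32, 480, 640)))  # (1, 480, 640, 32)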