TensorArray GlobalVar and GlobalTypeVar Confusion

I am working with the new TensorArray functionality in the TensorFlow frontend, and I'm getting into an area that I'm a little confused about:

It looks like all of the TensorArray functions and type definitions are registered as GlobalVars and GlobalTypeVars in the prelude's module. However, in the TensorFlow frontend, none of those definitions are passed to _infer_shape after converting an operator. Because of that, nodes like TensorArrayUnstack cannot find the necessary TensorArray function or type definitions during type inference.

Is this expected? Am I doing something wrong?

Edit: I tried passing the global vars and global type vars to _infer_shape, using self._prelude.mod.functions and self._prelude.mod.type_definitions. However, the call to out_type.checked_type.shape now fails with: tvm.relay.ty.TypeCall has no attribute shape. It seems like type inference isn't producing the real shape.

cc @wweic

Hey @jonso, a repro script would be helpful for debugging.

Looking at the code, I think we need to pass the module that contains all the prelude global vars into infer_type. Currently it creates a fresh module every time:

def infer_type(node):
    """A method to infer the type of an intermediate node in the relay graph."""
    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
    mod = _transform.InferType()(mod)
    entry = mod["main"]
    return entry if isinstance(node, _expr.Function) else entry.body


def infer_shape(inputs):
    """A method to get the output shape of an intermediate node in the graph."""
    out_type = infer_type(inputs)
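    # For ADT results (e.g. tensor_array ops) checked_type is a TypeCall rather
    # than a TensorType, so the .shape access below fails.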
    out_shapes = get_const_tuple(out_type.checked_type.shape)
    return out_shapes

Can you give an example of what you mean? I was originally trying to pass the functions and type_defs fields, but it gave the error mentioned above.

I’m thinking of the change like this: https://github.com/wweic/tvm/commit/f2bc89e15be89366e130181115070888cdf0c1aa

Basically when we infer the type of an expression, we should try to put the expression inside a module that has all the definitions.
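
A minimal sketch of the idea (the mod parameter and the use of free_vars to wrap the expression are my own illustration; the actual change is in the commit above):

def infer_type(node, mod=None):
    """Infer the type of node, reusing an existing module when one is given."""
    if mod is not None:
        # Put the expression into the module that already holds the prelude
        # definitions, so its GlobalVars/GlobalTypeVars resolve during inference.
        # (For the sketch we simply overwrite "main" and assume node is an expression.)
        mod["main"] = _expr.Function(_analysis.free_vars(node), node)  # _analysis: tvm.relay.analysis
        mod = _transform.InferType()(mod)
        return mod["main"].body
    # old behaviour: build a fresh module per call, without the prelude definitions
    new_mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
    new_mod = _transform.InferType()(new_mod)
    entry = new_mod["main"]
    return entry if isinstance(node, _expr.Function) else entry.body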

Got it.

I tried implementing this change, but the main function of the module returned by infer_type is still a CallNode that points to GlobalVar(tensor_array_scatter_float32). Its checked_type is a TypeCall, which doesn't seem to evaluate to a concrete tensor type.

Btw, I will work on getting a minimal reproducible script. It is a bit tough, though, since I'm working in the context of a large model; this node comes from a dynamic LSTM.

Here is a simple dynamic LSTM example that throws an error. It looks like the first handful of errors can be resolved by using the recently added _infer_value_simulated. After that, I am seeing a different error.

Let me know if this works for you. I think this would be a great scenario to support in TVM. Thanks a lot for the help :slight_smile:

import tvm
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from tvm import relay
from tvm.contrib import graph_runtime

pb_name = "lstm.pb"
hidden_dim = 4
shape_0 = (1,4,4) # input (batch, seq, input)
shape_1 = (1,) # seq len

shape_dict = {"Placeholder" : shape_0, "Placeholder_1" : shape_1}

input_0_np = np.random.random(size=shape_0).astype("float32")  # match the float32 placeholder
input_1_np = np.random.randint(1, 3, size=shape_1).astype("int32")  # match the int32 placeholder

def _remove_assert(all_nodes):
    # drop every Assert node, and strip references to Asserts from the remaining inputs
    new_nodes = []
    for node in all_nodes:
        if "assert" in node.name.lower():
            continue

        new_inputs = []
        for inp in node.input:
            if "assert" in inp.lower():
                continue
            else:
                new_inputs.append(inp)

        del node.input[:]
        node.input.extend(new_inputs)
        new_nodes.append(node)

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(new_nodes)
    return graph_def

def create_pb():
    with tf.Graph().as_default() as graph:
        x = tf.placeholder(tf.float32, shape=shape_0)
        y = tf.placeholder(tf.int32, shape=shape_1)

        lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_dim)
        output, (c_state, h_state) = tf.nn.dynamic_rnn(lstm_cell, x, y, dtype=tf.float32)
        output_add = tf.add(output, output)
        c_state_add = tf.add(c_state, c_state)
        h_state_add = tf.add(h_state, h_state)

        with tf.gfile.GFile(pb_name, "wb") as f:
            sess = tf.Session(graph = graph)
            sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def(add_shapes=True)
            graph_def = _remove_assert(graph_def.node)
            graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, ["Add", "Add_1", "Add_2"])
            graph_def = TransformGraph(
                    graph_def, # graph def
                    ["Placeholder", "Placeholder_1"], # inputs
                    ["Add", "Add_1", "Add_2"], # outputs
                    ["strip_unused_nodes",
                     "sort_by_execution_order",
                     "fold_batch_norms",
                     "sort_by_execution_order",
                     "fold_old_batch_norms",
                     "sort_by_execution_order",
                     ]# transforms
            )
            f.write(graph_def.SerializeToString())

def get_graph():
    with tf.gfile.GFile(pb_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name = "")
        return graph_def, graph

def run_tf(graph):
    with tf.Session(graph = graph) as sess:
        output_tensor_0 = tf.get_default_graph().get_tensor_by_name("Add" + ":0")
        output_tensor_1 = tf.get_default_graph().get_tensor_by_name("Add_1" + ":0")
        output_tensor_2 = tf.get_default_graph().get_tensor_by_name("Add_2" + ":0")
        placeholder_tensor = tf.get_default_graph().get_tensor_by_name("Placeholder:0")
        placeholder_1_tensor = tf.get_default_graph().get_tensor_by_name("Placeholder_1:0")
        output = sess.run([output_tensor_0, output_tensor_1, output_tensor_2], { placeholder_tensor : input_0_np, placeholder_1_tensor : input_1_np })

def run_tvm(graph_def):
    print("Before importing...")
    sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict, outputs=["Add", "Add_1", "Add_2"])
    print("Finished from tensorflow")

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(sym, target="llvm -mcpu=core-avx2", params=params)

    m = graph_runtime.create(graph, lib, tvm.cpu())

    m.set_input("Placeholder", input_0_np)
    m.set_input("Placeholder_1", input_1_np)
    m.set_input(**params)
    m.run()
    tvm_output_0=m.get_output(0)
    tvm_output_1=m.get_output(1)
    tvm_output_2=m.get_output(2)

create_pb()
graph_def, graph = get_graph()
run_tf(graph)
run_tvm(graph_def)

I get the following error with your script (TF version 1.13.1):

Traceback (most recent call last):

  File "fail.py", line 86, in <module>
    run_tvm(graph_def)

  File "fail.py", line 67, in run_tvm
    sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict, outputs=["Add", "Add_1", "Add_2"])

  File "/Users/wweic/workspace/tvm/python/tvm/relay/frontend/tensorflow.py", line 2475, in from_tensorflow
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)

  File "/Users/wweic/workspace/tvm/python/tvm/relay/frontend/tensorflow.py", line 2113, in from_tensorflow
    op = self._convert_operator(node.op, inputs, attr, graph)

  File "/Users/wweic/workspace/tvm/python/tvm/relay/frontend/tensorflow.py", line 2437, in _convert_operator
    sym = convert_map[op_name](inputs, attrs, self._params)

  File "/Users/wweic/workspace/tvm/python/tvm/relay/frontend/tensorflow.py", line 1052, in _impl
    else params.pop('Rank').asnumpy()[0]

KeyError: 'Rank'

This error occurs because input 2 of rnn/TensorArrayStack/range is not named 'Rank' and is not a constant. It can easily be fixed by using _infer_value_simulated(inputs[1], params).asnumpy()[0] (roughly the change sketched below). After this, I see the following error. Does the TensorArray in the prelude return both outputs?
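
Roughly the change I mean, assuming it lives in the _range converter's _impl in python/tvm/relay/frontend/tensorflow.py, where inputs, params, and _infer_value_simulated are already in scope (the "before" line is paraphrased from the traceback above, not copied from upstream):

# before: assumes the limit input is a constant named 'Rank'
limit = params.pop('Rank').asnumpy()[0]

# after: evaluate the input expression instead, which also handles
# rnn/TensorArrayStack/range, whose limit is computed rather than a named constant
limit = _infer_value_simulated(inputs[1], params).asnumpy()[0]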

Btw, here is the description of the second output value. It’s not very helpful. Maybe we can just automatically set it to a constant of 0?
flow: A scalar used to control gradient flow.
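
Something along these lines is what I have in mind for that; the self._nodes map and the ":1" key are just how I understand the frontend tracks node outputs, so treat it as a sketch rather than exact code:

# register the flow output (":1") of a TensorArray node as a scalar constant 0,
# so later nodes that reference it (e.g. rnn/TensorArray:1) can still resolve it
self._nodes[node.name + ":1"] = tvm.relay.const(0.0, "float32")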

I updated my script above to remove asserts from the graph def before exporting. That will get around the next error.

  /tvm/python/tvm/relay/frontend/tensorflow.py", line 2394, in _convert_control_flow_operator
    op = self._nodes[node.input[0]]

KeyError: 'rnn/TensorArray:1'

Hmm, could you share your branch with your current fixes?

The branch I’ve been working on is here. Here and here are the commits.

@wweic let me know if you’ve had a chance to take a look. I am also available for a call if you want to debug together.

Hi, I’ll try your branch this week. :slight_smile:

This is a common problem when using the prelude, since it defines many data types and global functions in the module. You have to pass the module into type inference whenever you use the data types defined in the ADT.
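
For illustration, here is a small self-contained example of that point; it uses the relay Python API as of this version (relay.Module, Prelude), which may differ in newer releases:

from tvm import relay
from tvm.relay.prelude import Prelude

mod = relay.Module()
p = Prelude(mod)  # registers the prelude ADTs and global functions in mod

# An expression that refers to prelude constructors/GlobalVars.
x = relay.var("x", shape=(3,), dtype="float32")
mod["main"] = relay.Function([x], p.cons(x, p.nil()))

# Type inference succeeds because main lives in the module that has the
# definitions; running InferType on a fresh module built from the expression
# alone would fail to resolve them.
mod = relay.transform.InferType()(mod)
print(mod["main"].checked_type)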

@jonso Could you work on a PR to fix this by passing the mod to type inference in the TF frontend converter?

Hi, I have tried the solution from your branch, but when I try to compile it still leads to other errors, such as the one below, which comes from the function _tensor_array_scatter():

values_rank = len(inputs[2].type_annotation.shape)

AttributeError: <class 'tvm.relay.expr.Call'> has no attribute type_annotation

PR sent: https://github.com/apache/incubator-tvm/pull/4287/

@OriAlpha How can I reproduce your error? Is it possible to try your branch and model?

@wweic I used an LSTM network with the MNIST dataset as input. You can build the model like this:

# assuming the usual Keras imports:
# from keras.models import Sequential
# from keras.layers import LSTM, Dense
model = Sequential()
model.add(LSTM(1, input_shape=(x_train.shape[1:]), activation='relu', return_sequences=True))
model.add(LSTM(1, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(10, activation='softmax'))

I used @jonso's branch for testing; with these, I think you can reproduce the error.

Does anyone have a solution for this?

KeyError: 'Rank'

@OriAlpha my branch should have a fix for this.