Type mismatch in static tensor array

Consider this case:

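```python
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.Dimension, graph mode)
input_shape = (5, 5, 800)  # unstacking along axis 0 yields elements of shape (5, 800)
dtype = tf.float32
t = tf.constant(np.random.choice([0, 1, 2, 3],
                                 size=input_shape).astype(dtype.name))
# The declared element shape is only partially known: (?, 800)
ta1 = tf.TensorArray(dtype=dtype, infer_shape=True, size=input_shape[0],
                     element_shape=tf.TensorShape([tf.Dimension(None), 800]))
ta2 = ta1.unstack(t)
out1 = ta2.read(0)
```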

Relay creates a GlobalVar for the tensor array named `static_tensor_array_?_800_t`, where the `?` comes from the partially known element shape (?, 800).

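However, when constructing `tensor_array_scatter_func` (TensorFlow implements `TensorArray.unstack` with a `TensorArrayScatterV3` op), the frontend creates the type `static_tensor_array_5_800_t` from the concrete element shape (5, 800).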

These are two distinct types, so Relay cannot unify them and type inference fails.
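To make the mismatch concrete, here is a minimal Python sketch of a shape-to-name scheme that reproduces the two type names quoted above. The helper `static_ta_type_name` is hypothetical and is not the actual Relay implementation; it only illustrates that encoding `?` in one name and `5` in the other produces two unrelated types.

```python
from typing import Optional, Sequence

Dim = Optional[int]  # None stands for an unknown dimension ("?")


def static_ta_type_name(elem_shape: Sequence[Dim]) -> str:
    """Hypothetical helper: encode each dimension into the type name,
    writing unknown dimensions as '?'."""
    dims = "_".join("?" if d is None else str(d) for d in elem_shape)
    return f"static_tensor_array_{dims}_t"


# Name derived from the TensorArray's declared element_shape (?, 800)
declared = static_ta_type_name([None, 800])
# Name derived from the unstacked value's concrete element shape (5, 800)
scattered = static_ta_type_name([5, 800])

print(declared)   # static_tensor_array_?_800_t
print(scattered)  # static_tensor_array_5_800_t
# In this sketch, different names stand for different types: the reported mismatch.
print(declared == scattered)  # False
```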

There may be two solutions:

  1. Recreate the tensor array when its `elem_shape` does not match what `tensor_array_scatter_func` needs.
  2. Enhance the type solver so it can unify a partially known shape such as (?, 800) with a concrete one such as (5, 800).
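Both options rest on the same check: whether a partially known shape like (?, 800) is compatible with a concrete one like (5, 800), and what the refined shape is. Below is a minimal sketch, assuming shapes are plain Python sequences with `None` for unknown dimensions. `unify_shapes` and `choose_action` are hypothetical helpers, not existing Relay APIs. `choose_action` mirrors option 1 (decide whether to reuse or recreate the tensor array); option 2 would apply the equivalent rule inside the type solver.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None stands for an unknown dimension ("?")


def unify_shapes(a: Sequence[Dim], b: Sequence[Dim]) -> Optional[List[Dim]]:
    """Return the most specific shape compatible with both a and b,
    or None if they conflict (different rank or different known sizes)."""
    if len(a) != len(b):
        return None
    result: List[Dim] = []
    for da, db in zip(a, b):
        if da is None:
            result.append(db)        # an unknown dim unifies with anything
        elif db is None or da == db:
            result.append(da)
        else:
            return None              # two different concrete sizes
    return result


def choose_action(declared: Sequence[Dim], required: Sequence[Dim]) -> str:
    """Option 1 sketch: decide what to do with the existing tensor array."""
    if list(declared) == list(required):
        return "reuse"
    if unify_shapes(declared, required) is not None:
        return "recreate"            # rebuild with the refined element shape
    return "error"                   # the shapes genuinely conflict


print(unify_shapes([None, 800], [5, 800]))   # [5, 800]
print(choose_action([None, 800], [5, 800]))  # recreate
```

Treating `None` as a wildcard keeps the rule symmetric, so it can also serve as the core of a unification step in option 2.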

@kevinthesun Could you please check this case and share your thoughts? Thanks a lot!

I’m refactoring the tensor array support in the TF frontend, and this will be fixed as part of that work.