Operator MUL,RESIZE_BILLINEAR not implemented in frontend TFlite

when I try a pretrained tflite model, get a error:
tvm.error.OpNotImplemented: The following operators are not supported in frontend TFLite: ‘MUL’,‘RESIZE_BILINEAR’

I encourage buddy to contribute. It is not difficult. However, I could implement them after this PR is done. https://github.com/dmlc/tvm/pull/3141#issuecomment-492480777

ok,i got it ,thanks…

sorry, my codes has errors,can you help me?thanks :slight_smile:

   def convert_resizebilinear(self, op):
    """Convert TFLite resizebilinear"""
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.ResizeBilinearOptions import ResizeBilinearOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")
    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    output_tensors = self.get_output_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"
    #assert len(output_tensors) == 2, "output tensors length should be 2"
    input_tensor = input_tensors[0]
    input_tensor_idx = input_tensor.tensor_idx

    assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
    op_options = op.BuiltinOptions()
    resizebilinear_options = ResizeBilinearOptions()
    resizebilinear_options.Init(op_options.Bytes, op_options.Pos)

    input_shape= input_tensor.tensor.ShapeAsNumpy()
    input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())

    output_shape  =  output_tensors[0].tensor.ShapeAsNumpy()
    output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy())

    scaleSize = int(output_shape[1]/input_shape[1])
    
    print("************{}",input_tensors[0].tensor.ShapeAsNumpy())
    print("************{}",output_tensors[0].tensor.ShapeAsNumpy())
    print("************{}",scaleSize)

    in_expr = self.get_expr(input_tensor_idx)

    # TFLite is N H W C, our layout is N C H W
    if input_shape_length in (1, 2):
        # The rule is channel first (after N but before H, W).
        # length of 1 means N*H*W*C, do nothing.
        # length of 2 means N*H*W, C, do nothing.
        pass
    elif input_shape_length == 3:
        # convert N C H*W to N H*W C
        in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
    elif input_shape_length == 4:
        # convert input to N H W C, then reshape to target shape,
        # finally convert back if necessary
        in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
    else:
        msg = 'Input shape length {} for operator Squeeze is not valid.'
        raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

    out = _op.nn.upsampling(in_expr,scale=scaleSize,method="BILINEAR")

    return out

MUL op codes , is right? thanks

def convert_mul(self, op):
    """Convert TFLite mul """
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    lhs_tensor = input_tensors[0]
    lhs_expr = self.get_expr(lhs_tensor.tensor_idx)

    rhs_tensor = input_tensors[1]
    if self.has_expr(rhs_tensor.tensor_idx):
        # In most cases, we can assume that TOCO fuses ADD operators
        # with constants - it means both will be tensors.
        rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
    else:
        # However, in some corner cases, the ADD operator is not fused,
        # we can receive as constant.
        rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
        rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                          dtype=rhs_type_str)

        # In this case, we have to be careful about formatting.
        input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
        if input_shape_length in (1, 2):
            pass
        elif input_shape_length == 3:
            # N H*W C to N C H*W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
        elif input_shape_length == 4:
            # N H W C to N C H W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
        else:
            msg = 'Input shape length {} for operator ADD is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

    out = _op.nn.batch_matmul(lhs_expr, rhs_expr)
    return out

Resize’s implementation is not correct. Firstly, TFLite has two input tensors. One is input, another one is resized value (i.e. target size). You should get target size from input_tensor[1]. Another one is you miss align_corners attribute. The third is you don’t need transpose and shouldn’t use up_sampling, you should call _op.image.resize . Please refer mxnet implementation

Mul’s implementation is not correct too. You shouldn’t call batch_matmul, which is matrix multiply, not MUL. should call _op.multiply. I suggest you write unit testing to verify your implementation.

i have finished it,i can convert the tflite model to tvm, but i don’t test the detected results ,can you give me a review,thanks

 def convert_resizebilinear(self, op):
    """Convert TFLite resizebilinear"""
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.ResizeBilinearOptions import ResizeBilinearOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    output_tensors = self.get_output_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"
    assert len(output_tensors) == 1, "output tensors length should be 1"
   
    input_tensor = input_tensors[0]
    input_tensor_idx = input_tensor.tensor_idx
    
    assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
    op_options = op.BuiltinOptions()
    resizebilinear_options = ResizeBilinearOptions()
    resizebilinear_options.Init(op_options.Bytes, op_options.Pos)
    output_shape  =  output_tensors[0].tensor.ShapeAsNumpy()
    in_expr = self.get_expr(input_tensor_idx)

    size = (output_shape[1], output_shape[2])
    #out = _op.image.resize(in_expr,size,align_corners=resizebilinear_options.AlignCorners)
    out = _op.image.resize(in_expr,size,align_corners=True)
    return out

def convert_mul(self, op):
    """Convert TFLite mul """
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    lhs_tensor = input_tensors[0]
    lhs_expr = self.get_expr(lhs_tensor.tensor_idx)

    rhs_tensor = input_tensors[1]
    if self.has_expr(rhs_tensor.tensor_idx):
        # In most cases, we can assume that TOCO fuses ADD operators
        # with constants - it means both will be tensors.
        rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
    else:
        # However, in some corner cases, the ADD operator is not fused,
        # we can receive as constant.
        rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
        rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                          dtype=rhs_type_str)

        # In this case, we have to be careful about formatting.
        input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
        if input_shape_length in (1, 2):
            pass
        elif input_shape_length == 3:
            # N H*W C to N C H*W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
        elif input_shape_length == 4:
            # N H W C to N C H W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
        else:
            msg = 'Input shape length {} for operator ADD is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

    out = _op.multiply(lhs_expr,rhs_expr)
    return out

align_corners should not be set True. You should get the value from the ResizeBilinearOptions. Target Shape should be get from input_tensors[1]. You also miss the layout parameter for op.image.resize.

I suggest you could implement these ops after my pr merged. Because it will affect your implementation. For example, resize bilinear’s layout parameter / mul op’s handle way.

Lastly, you should verify the correctness writing unit testing.

thanks, i can not get the target shape from the input_tensors[1],i have tested it:

print("*****input_tensor[0]=",input_tensors[0].tensor.ShapeAsNumpy())
print("*****input_tensor[1]=",input_tensors[1].tensor.ShapeAsNumpy())

the results show:

*****input_tensor[0]= [ 1 32 32 64]
*****input_tensor[1]= [2]

so i just get it from the output_tensor as follows:

output_tensors = self.get_output_tensors(op)

i do not know how to get target shape from input_tensors[1],can you help me ?

i have a question:
we have konw that as above result: input_tensor[0]= [ 1 32 32 64]
so the input tensor layout is “NHWC”,
but i set op.image.resize() : layout=“NCHW”, it works well
i set layout=“NHWC”,it gives an error
why?

another question:
if your PR has merged, the layout =“NHWC”,but autotvm just supports “NCHW”,so how i use the autotvm? thanks

You should call self.get_tensor_value for input_tensor[1] to get value.

Because our input[0] symbol have been translated in NCHW in previous layer, your print is just TFLite’s result, not TVM, so the layout should be set NCHW.

Good question. If my pr is merged, you could not use AutoTVM. You have to wait I implement Spatial Pack schedule for NHWC. Or I encourage you to implement too. You could refer current implementation, it is not hard.

thanks for your help,currently can i use autotvm to tune the tflite model? just now i encounter some problems in tuning the tflite model by autotvm.

of course you could tune.

Hi, FrozenGene:
I met a problem when converting uint8 tflite mode to relay IR.
Can you give me some advice? Thanks!

Traceback (most recent call last):
File “/usr/lib/python3.5/pdb.py”, line 1661, in main
pdb._runscript(mainpyfile)
File “/usr/lib/python3.5/pdb.py”, line 1542, in _runscript
self.run(statement)
File “/usr/lib/python3.5/bdb.py”, line 431, in run
exec(cmd, globals, locals)
File “”, line 1, in
File “/mnt/fu02/zjzhang/tvm_fu/tvm/tests/futest/cpu/uint8/linux/from_tflite.py”, line 63, in
graph, lib, params = relay.build(func, target, params=params)
File “/mnt/fu02/zjzhang/tvm_fu/tvm/python/tvm/relay/build_module.py”, line 276, in build
func = optimize(func, target, params)
File “/mnt/fu02/zjzhang/tvm_fu/tvm/python/tvm/relay/build_module.py”, line 163, in optimize
func = ir_pass.infer_type(func)
File “/mnt/fu02/zjzhang/tvm_fu/tvm/python/tvm/relay/ir_pass.py”, line 353, in infer_type
return _ir_pass.infer_type(expr, mod)
File “/mnt/fu02/zjzhang/tvm_fu/tvm/python/tvm/_ffi/_ctypes/function.py”, line 190, in call
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):

In main:
v0.0.1
%209 = fn (%Image: Tensor[(1, 224, 224, 3), uint8]) {
%0 = nn.pad(%Image, pad_width=[[0, 0], [3, 3], [3, 3], [0, 0]]) //
%1 = nn.conv2d(%0, meta[relay.Constant][0] // , strides=[4, 4], channels=32, kernel_size=[7, 7], data_layout=“NHWC”, kernel_layout=“HWIO”) //
%2 = nn.bias_add(%1, meta[relay.Constant][1] // , axis=3) // an internal invariant was violdated while typechecking your program [13:33:40] /mnt/fu02/zjzhang/tvm_fu/tvm/src/relay/pass/type_solver.cc:100: Check failed: resolved.defined(): Unable to unify parent types: TensorType([32], uint8) and TensorType([32], int32)
;
%3 = nn.pad(%2, pad_width=[[0, 0], [0, 0], [0, 0], [0, 0]]) //
%4 = nn.conv2d(%3, meta[relay.Constant][2] // , strides=[2, 2], channels=196, kernel_size=[1, 1], data_layout=“NHWC”, kernel_layout=“HWIO”) //
%5 = nn.bias_add(%4, meta[relay.Constant][3] // , axis=3) //

The quantized model support is far from done as of now :slight_smile: We are discussing the proposal here.