The following simple code snippet generates an incorrect result:
import os
import numpy as np
import tvm
from tvm import relay
import tvm.relay.op as _op
from tvm.ir import IRModule
import tvm.relay.expr as _expr
from tvm.runtime.vm import VirtualMachine
# Build a small Relay graph: layout_transform -> reshape -> transpose -> split.
dshape = (1, 2, 8)
data = relay.var("data", shape=dshape, dtype="float32")
o0 = _op.layout_transform(data, "NC8c", "NC")
o0 = _op.reshape(o0, [-1, 4])
o0 = _op.transpose(o0, [1, 0])
o0 = _op.split(o0, indices_or_sections=4)
# Only the first of the four split outputs is returned by the function.
func = relay.Function([data], o0[0])
mod = IRModule()
mod["main"] = func

# Compile for the Relay VM at opt_level=3. The compile call is the body of
# the `with` statement, so it must be indented for the script to parse.
with relay.build_config(opt_level=3):
    vm_exec = relay.vm.compile(mod, target="llvm")

# Load the compiled executable into a VM running on the CPU.
vm = VirtualMachine(vm_exec)
ctx = tvm.cpu()
vm.init(ctx)

# Two identical rows 0..7, matching dshape (1, 2, 8).
in_data = np.array([[[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7]]]).astype("float32")
res = vm.invoke("main", [in_data])
print(res.asnumpy())
The expected result is [[0. 4. 0. 4.]], but the script actually prints [[0. 1. 0. 1.]].
This code snippet is extracted from a larger model. As a temporary workaround, I can make the model's output correct by setting the fusion pattern of either split or transpose to kOutEWiseFusable.