Strange overhead of tvm.runtime.ndarray.from_dlpack

Hi all, I’ve noticed that from_dlpack costs a lot when the execution time of rt_mod is brief.

rt_mod only: ~0.0332 ms per call

rt_mod with dlpack: ~0.0538 ms per call

dlpack only: ~0.0338 ms per call

(The scripts below print end - start in seconds over 1000 iterations, so these values read directly as milliseconds per call.)

code to reproduce:

from tvm.script import ir as I
from tvm.script import tir as T
import tvm
import torch
from torch.utils.dlpack import to_dlpack

@I.ir_module
class Module:
    @T.prim_func
    def main(
        A: T.Buffer((1, 1024), "float16"),
        B: T.Buffer((1024, 512), "int8"),
        Scale: T.Buffer((1024, 1), "float16"),
        Zeros: T.Buffer((1024, 1), "float16"),
        D: T.Buffer((1, 1024), "float16"),
    ):
        T.func_attr(
            {
                "dequantize_info": {
                    "B_decode": {
                        "decode_block": "B_decode",
                        "fast_decoding": T.bool(False),
                        "group_size": 1024,
                        "source_format": {"bits": 4, "format": "uint"},
                        "storage_dtype": "int8",
                        "target_format": "float16",
                        "with_scaling": T.bool(True),
                        "with_zeros": T.bool(True),
                        "zeros_type": "original",
                    }
                },
                "tir.noalias": T.bool(True),
            }
        )
        # with T.block("root"):
        B_decode_local = T.alloc_buffer((1024, 1024), "float16", scope="local")
        A_local = T.alloc_buffer((1, 1024), "float16", scope="local")
        B_local = T.alloc_buffer((1024, 512), "int8", scope="local")
        C_local = T.alloc_buffer((1, 1024), "float16", scope="local")
        for ax0_0 in T.thread_binding(512, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(2):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0 in range(1):
                            for ax1 in T.vectorized(4):
                                with T.block("B_local"):
                                    v0 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1 + ax0)
                                    v1 = T.axis.spatial(
                                        512, ax1_0 * 256 + ax1_1 * 4 + ax1
                                    )
                                    T.reads(B[v0, v1])
                                    T.writes(B_local[v0, v1])
                                    B_local[v0, v1] = B[v0, v1]
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(1024, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(
                                    B_local[v0, v1 // 2], Zeros[v0, 0], Scale[v0, 0]
                                )
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = (
                                    T.Cast(
                                        "float16",
                                        T.bitwise_and(
                                            T.shift_right(
                                                B_local[v0, v1 // 2],
                                                T.Cast("int8", v1 % 2 * 4),
                                            ),
                                            T.int8(15),
                                        ),
                                    )
                                    - Zeros[v0, 0]
                                ) * Scale[v0, 0]
                        for ax0 in range(1):
                            for ax1 in T.vectorized(8):
                                with T.block("A_local"):
                                    v0 = T.axis.spatial(1, ax0)
                                    v1 = T.axis.spatial(
                                        1024, ax1_0 * 512 + ax1_1 * 8 + ax1
                                    )
                                    T.reads(A[v0, v1])
                                    T.writes(A_local[v0, v1])
                                    A_local[v0, v1] = A[v0, v1]
                        for ax1_2_0, ax1_2_1 in T.grid(4, 2):
                            with T.block("C"):
                                v0 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1)
                                v1 = T.axis.reduce(
                                    1024,
                                    ax1_0 * 512 + ax1_1 * 8 + ax1_2_0 * 2 + ax1_2_1,
                                )
                                T.reads(A_local[0, v1], B_decode_local[v0, v1])
                                T.writes(C_local[0, v0])
                                with T.init():
                                    C_local[0, v0] = T.float16(0)
                                C_local[0, v0] = (
                                    C_local[0, v0]
                                    + A_local[0, v1] * B_decode_local[v0, v1]
                                )
                for ax0, ax1 in T.grid(1, 1):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1 + ax1)
                        T.reads(C_local[v0, v1])
                        T.writes(D[0, v1])
                        D[0, v1] = C_local[v0, v1]

target = tvm.target.Target("cuda")
with tvm.transform.PassContext():
    rt_mod = tvm.build(Module, target=target)

torch_tensors = []
input_tensor = torch.randn(1, 1024).half().cuda()
weight_tensor = torch.randint(-8, 8, (1024, 512), dtype=torch.int8).cuda()
scale_tensor = torch.randn(1024).half().cuda().reshape(-1, 1)
zero_tensor = torch.zeros(1024).half().cuda().reshape(-1, 1)
output_tensor = torch.empty(1, 1024).half().cuda()
torch_tensors.append(input_tensor)
torch_tensors.append(weight_tensor)
torch_tensors.append(scale_tensor)
torch_tensors.append(zero_tensor)
torch_tensors.append(output_tensor)

tvm_nd_array_tensors = [
    tvm.runtime.ndarray.from_dlpack(to_dlpack(torch_tensor))
    for torch_tensor in torch_tensors
]

import time
start = time.time()
for _ in range(1000):
    rt_mod(*tvm_nd_array_tensors)
end = time.time()
print("rt_mod only Time: ", end - start)


import time

start = time.time()
for _ in range(1000):
    dlpack_tensors = [
        to_dlpack(torch_tensor) for torch_tensor in torch_tensors
    ]
    tvm_nd_array_tensors = [
        tvm.runtime.ndarray.from_dlpack(dlpack_tensor)
        for dlpack_tensor in dlpack_tensors
    ]
    rt_mod(*tvm_nd_array_tensors)
end = time.time()
print("rt_mod with dlpack Time: ", end - start)


import time

start = time.time()
for _ in range(1000):
    dlpack_tensors = [
        to_dlpack(torch_tensor) for torch_tensor in torch_tensors
    ]
    tvm_nd_array_tensors = [
        tvm.runtime.ndarray.from_dlpack(dlpack_tensor)
        for dlpack_tensor in dlpack_tensors
    ]
end = time.time()
print("dlpack only Time: ", end - start)

Do we have any method to run the runtime module forward directly from pointers?

Using offline casting of the tensor types can significantly alleviate the cost. After breakdown profiling, I noticed two primary overheads in the runtime module forward: rt_mod.get_function(rt_mod.entry_name), and the cost of reconstructing tensors from the DLPack format.
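As a rough way to separate the two costs, here is a small sketch against the script above (it assumes rt_mod and torch_tensors are already defined; the loops and format strings are illustrative, not the exact profiling I ran):

import time

import tvm
from torch.utils.dlpack import to_dlpack

# cost of the PackedFunc lookup by itself
start = time.time()
for _ in range(1000):
    f = rt_mod.get_function(rt_mod.entry_name)
end = time.time()
print("get_function: {:.2f} us per call".format((end - start) * 1e3))

# cost of the DLPack -> NDArray reconstruction for all five arguments
start = time.time()
for _ in range(1000):
    args = [tvm.runtime.ndarray.from_dlpack(to_dlpack(t)) for t in torch_tensors]
end = time.time()
print("from_dlpack: {:.2f} us per call (5 tensors)".format((end - start) * 1e3))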

Here’s a refined version of the code snippet:

import ctypes
import time

# note: these are TVM-internal modules; exact paths may differ across versions
from tvm._ffi.base import _LIB, check_call, raise_last_ffi_error
from tvm._ffi.runtime_ctypes import ArgTypeCode, TVMArrayHandle, tvm_shape_index_t
from tvm._ffi._ctypes.types import TVMValue
from tvm._ffi._ctypes.ndarray import (
    _c_str_dltensor,
    _c_str_used_dltensor,
    TVMPyCapsuleDestructor,
)

function = rt_mod.get_function(rt_mod.entry_name)
function_handle = function.handle
num_args = 5
start = time.time()

values = (TVMValue * num_args)()
tcodes = (ctypes.c_int * num_args)()
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
for i in range(num_args):
    tcodes[i] = ArgTypeCode.NDARRAY_HANDLE

dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in torch_tensors]

for i, dltensor in enumerate(dlpack_tensors[1:]):
    dltensor = ctypes.py_object(dltensor)
    if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
        ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
        # enforce type to make sure it works for all ctypes
        ptr = ctypes.cast(ptr, ctypes.c_void_p)
        handle = TVMArrayHandle()
        check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
        ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
        ctypes.pythonapi.PyCapsule_SetDestructor(
            dltensor, TVMPyCapsuleDestructor(0)
        )
        values[i + 1].v_handle = ctypes.cast(handle, ctypes.c_void_p)

for _ in range(1000):
    dltensor = dlpack_tensors[0]
    dltensor = ctypes.py_object(dltensor)
    if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
        ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
        # enforce type to make sure it works for all ctypes
        ptr = ctypes.cast(ptr, ctypes.c_void_p)
        handle = TVMArrayHandle()
        check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
        ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
        ctypes.pythonapi.PyCapsule_SetDestructor(
            dltensor, TVMPyCapsuleDestructor(0)
        )
        values[0].v_handle = ctypes.cast(handle, ctypes.c_void_p)
    if (
        _LIB.TVMFuncCall(
            function_handle,
            values,
            tcodes,
            ctypes.c_int(num_args),
            ctypes.byref(ret_val),
            ctypes.byref(ret_tcode),
        )
        != 0
    ):
        raise_last_ffi_error()

We can move this setup out of the runtime stage and keep only the dynamic input tensor conversion during forward if we want better performance with small shapes. (I ran into this issue while integrating a TVM rt_mod directly into PyTorch.)

It would also be great to see how we can run forward directly from device pointers instead of constructing and transforming data handles.

The function itself can be fetched once up front, instead of calling get_function multiple times.

Trying the cython API, or directly calling a TVM build compiled with cython, should also help.
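For reference, one way to make sure the cython FFI (rather than the ctypes fallback) is actually in use is the TVM_FFI environment variable; the check below reflects my understanding of the internal package layout and may differ across TVM versions:

import os

# must be set before "import tvm"; with "cython" TVM errors out instead of
# silently falling back to ctypes when the cython core was not built
os.environ["TVM_FFI"] = "cython"

import tvm  # noqa: E402

# crude check of which backend was loaded: the base class of PackedFunc comes
# either from the cython core or from the ctypes fallback
print(tvm.runtime.packed_func.PackedFunc.__mro__[1].__module__)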

Thanks tq. I added a new NDArray cython API to directly set the data pointer and attributes:

NDArray NDArray::FromDLAttributes(void* data, DLDataType dtype, DLDevice dev, int ndim,
                                  int64_t* shape) {
  NDArray::Container* container = new NDArray::Container();
  container->SetDeleter(Internal::SelfDeleter);
  // TODO(lei): append some checks here
  // fill up content.
  DLTensor from;
  from.data = const_cast<void*>(data);
  from.device = dev;
  from.ndim = ndim;
  from.dtype = dtype;
  from.shape = shape;
  from.strides = nullptr;
  from.byte_offset = 0;
  container->dl_tensor = from;
  ICHECK(IsAligned(container->dl_tensor))
      << "Data in DLManagedTensor is not aligned as required by NDArray";
  // update shape_
  std::vector<ShapeTuple::index_type> container_shape;
  container_shape.resize(from.ndim);
  container_shape.assign(from.shape, from.shape + from.ndim);
  container->shape_ = ShapeTuple(container_shape);
  return NDArray(GetObjectPtr<Object>(container));
}

looks like it works:

for i, torch_tensor in enumerate(torch_tensors):
    attr_handle = TVMArrayHandle()
    data_ptr = ctypes.cast(torch_tensor.data_ptr(), ctypes.c_void_p)
    dtype = tvm.DataType(str(torch_tensor.dtype).replace('torch.', ''))
    ndim = len(torch_tensor.shape)
    shape = ctypes.cast(
        (tvm_shape_index_t * ndim)(*torch_tensor.shape),
        ctypes.POINTER(tvm_shape_index_t),
    )
    torch_device = torch_tensor.device
    device = tvm.runtime.device(torch_device.type, torch_device.index)
    check_call(
        _LIB.TVMArrayFromDLAttributes(
            data_ptr,
            dtype,
            device,
            ctypes.c_int32(ndim),
            shape,
            ctypes.byref(attr_handle),
        )
    )
    # hand the constructed handle to the corresponding packed-func argument slot
    values[i].v_handle = ctypes.cast(attr_handle, ctypes.c_void_p)

if (
    _LIB.TVMFuncCall(
        function_handle,
        values,
        tcodes,
        ctypes.c_int(num_args),
        ctypes.byref(ret_val),
        ctypes.byref(ret_tcode),
    )
    != 0
):
    raise_last_ffi_error()

Can we cross-check cython’s FromDLPack? I am a bit surprised if that becomes a bottleneck.

Yeah, I’ve dug further into this; let me share some discoveries.

Actually, the kernel itself only takes around 4 us (1x1024x1024, int4b) on my A100 device, and the time evaluator reports similar results.

tvm_nd_array_tensors = [
    tvm.runtime.ndarray.from_dlpack(to_dlpack(torch_tensor))
    for torch_tensor in torch_tensors
]

time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(), number=1000000)

latency = time_evaluator(*tvm_nd_array_tensors).mean * 1e6
print("rt_mod time_evaluator Time: {:.2f} us". Format(latency))
# rt_mod time_evaluator Time: 4.39 us

If we directly profile rt_mod, we get around 13+ us per call:

# warmup
for _ in range(1000):
    rt_mod(*tvm_nd_array_tensors)

start = time.time()
for _ in range(1000000):
    rt_mod(*tvm_nd_array_tensors)
end = time.time()
print("rt_mod only Time: {:.2f} us".format(float(end - start)))
# rt_mod only Time: 13.44 us

If we dynamically convert the tensors with DLPack, we get around 53 us per call:

# warmup
for _ in range(1000):
    dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in torch_tensors]
    tvm_nd_array_tensors = [
        tvm.runtime.ndarray.from_dlpack(dlpack_tensor)
        for dlpack_tensor in dlpack_tensors
    ]
    rt_mod(*tvm_nd_array_tensors)
start = time.time()
for _ in range(1000000):
    dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in torch_tensors]
    tvm_nd_array_tensors = [
        tvm.runtime.ndarray.from_dlpack(dlpack_tensor)
        for dlpack_tensor in dlpack_tensors
    ]
    rt_mod(*tvm_nd_array_tensors)
end = time.time()
print("rt_mod with dlpack Time: {:.2f} us".format(float(end - start)))
# rt_mod with dlpack Time: 53.40 us

Using the method I mentioned above to create a TVM array handle directly from the data pointer, we get around 20 us:

time_arr = []
for _ in range(100):
    start = time.time()
    for _ in range(1000):
        for i, torch_tensor in enumerate(torch_tensors):
            attr_handle = TVMArrayHandle()
            data = ctypes.cast(torch_tensor.data_ptr(), ctypes.c_void_p)
            check_call(
                _LIB.TVMArrayFromDataPointerOnly(
                    data,
                    device,
                    ctypes.byref(attr_handle),
                )
            )
            values[i].v_handle = ctypes.cast(attr_handle, ctypes.c_void_p)

        check_call(
            _LIB.TVMFuncCall(
                function_handle,
                values,
                tcodes,
                ctypes.c_int(num_args),
                ctypes.byref(ret_val),
                ctypes.byref(ret_tcode),
            )
        )
    torch.cuda.synchronize()
    end = time.time()
    time_arr.append(end - start)
print("Time: ", time_arr)
print("Overall Time{:.2f} us".format(sum(time_arr) / len(time_arr) * 1e3))
# Overall Time 21.97 us

However, it remains about 5 times slower than executing the kernel directly. My analysis suggests that the primary overhead comes from the FFI layer: converting dynamic Python objects (like the shape list) with ctypes, and invoking the cdef extern C library calls through cython; both take several microseconds each…
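To back that up, here is a tiny standalone micro-benchmark (illustration only, no TVM involved; the shape and iteration count are arbitrary) of the per-tensor ctypes shape conversion done above:

import ctypes
import time

shape = (1, 1024)
ndim = len(shape)
iters = 100000

start = time.time()
for _ in range(iters):
    # build the int64 shape array and cast it to a pointer, as done per tensor above
    arr = (ctypes.c_int64 * ndim)(*shape)
    ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_int64))
end = time.time()
print("ctypes shape conversion: {:.2f} us per iteration".format((end - start) / iters * 1e6))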

To dig further, I wrote a CUDA source wrapper that compiles the CUDA source into a shared library and invokes it directly through ctypes, without any of the cython-related machinery; with that, the latency matches the kernel time.
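The wrapper module itself is not shown here; below is a minimal sketch of what it could look like (the load_lib signature, the nvcc flags, and the temporary build directory are illustrative assumptions; the real wrapper bakes the generated CUDA source in and exposes an extern "C" call entry that launches the kernel from raw device pointers):

import ctypes
import os
import subprocess
import tempfile


def load_lib(cuda_src: str, entry: str = "call") -> ctypes.CDLL:
    # compile a CUDA source string into a shared library and load it via ctypes
    workdir = tempfile.mkdtemp()
    src_path = os.path.join(workdir, "kernel.cu")
    lib_path = os.path.join(workdir, "kernel.so")
    with open(src_path, "w") as f:
        f.write(cuda_src)
    # build a position-independent shared object with nvcc
    subprocess.check_call(
        ["nvcc", "--shared", "-Xcompiler", "-fPIC", "-o", lib_path, src_path]
    )
    lib = ctypes.CDLL(lib_path)
    getattr(lib, entry).restype = None  # the launcher returns nothing
    return lib

With the wrapper loaded, the benchmark runs the kernel from raw device pointers: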

import time
import ctypes

lib = wrapper.load_lib()
torch.cuda.synchronize()
time_arrs = []
for __ in range(100):
    start = time.time()
    for _ in range(1000):
        lib.call(*[ctypes.c_void_p(arr.data_ptr()) for arr in torch_tensors])
    torch.cuda.synchronize()
    end = time.time()
    time_arrs.append(end - start)
print("time arrs: ", time_arrs)
print("lib only Time: ", sum(time_arrs) / len(time_arrs))
print(output_tensor)
# lib only Time:  4.469914436340332  us