Parallel prefix scan in TVM

I was wondering if it is possible to schedule parallel prefix scan in TVM’s Halide IR.

The question comes from implementing numpy operators like numpy.cumsum. Also, all the ufuncs (like broadcast_add) have the interface ufunc.accumulate, which lets users do a prefix scan with ease.
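To make the target semantics concrete, here is a small NumPy sketch of the operations mentioned above. The input values are just illustrative; np.add, np.multiply and np.maximum stand in for the ufuncs a frontend would need to support.

```python
import numpy as np

x = np.array([1, 2, 3, 4])

# numpy.cumsum is the inclusive prefix sum along an axis.
prefix_sum = np.cumsum(x)

# ufunc.accumulate generalizes this to any binary ufunc.
prefix_sum_ufunc = np.add.accumulate(x)
prefix_prod = np.multiply.accumulate(x)
prefix_max = np.maximum.accumulate(np.array([3, 1, 4, 1, 5]))
```

For np.add, accumulate and cumsum give the same result, so a single scan primitive with a pluggable binary operator would cover both.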

For now, TVM has the primitive tvm.scan, which is semantically equivalent to a prefix scan. However, loops containing scan.idx are harder to schedule than other loops, because of the sequential dependency between steps.

If we care about efficiency, what we probably need is a reduction on a binary tree, as many textbooks describe (like this). However, that changes the semantics of the loops, so I suspect it is hard to implement using TVM’s Halide IR.
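For reference, here is a minimal pure-Python sketch of the kind of binary-tree scan I have in mind. I am assuming the textbook up-sweep/down-sweep (Blelloch-style) exclusive scan; the function name blelloch_scan and its parameters are made up for illustration and are not a TVM API. Within one stride, the inner loop iterations are independent of each other, which is what would map to parallel threads.

```python
import operator


def blelloch_scan(values, op=operator.add, identity=0):
    """Exclusive prefix scan using up-sweep / down-sweep on an implicit binary tree.

    Hypothetical helper for illustration; `op` must be associative and
    `identity` must be its identity element.
    """
    n = len(values)
    # Pad to the next power of two so the implicit tree is complete.
    size = 1
    while size < n:
        size *= 2
    tree = list(values) + [identity] * (size - n)

    # Up-sweep: each pass combines pairs of partial results one level up.
    stride = 1
    while stride < size:
        for i in range(2 * stride - 1, size, 2 * stride):  # independent iterations
            tree[i] = op(tree[i - stride], tree[i])
        stride *= 2

    # Down-sweep: push prefixes back down from the root.
    tree[size - 1] = identity
    stride = size // 2
    while stride >= 1:
        for i in range(2 * stride - 1, size, 2 * stride):  # independent iterations
            left = tree[i - stride]
            tree[i - stride] = tree[i]      # left child inherits the parent's prefix
            tree[i] = op(tree[i], left)     # right child: parent's prefix, then left sum
        stride //= 2

    return tree[:n]


result = blelloch_scan([3, 1, 7, 0, 4, 1, 6, 3])
```

Note that the loop over stride is still sequential and the loop bounds depend on it, so this is a different loop nest from the single sequential scan axis, not just a different schedule of the same one.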

Any ideas?


@Laurawly @were could you share your ideas about this?

@junrushao I implemented a simplified parallel prefix sum in this PR: https://github.com/dmlc/tvm/pull/2784
It’s a sub op used by get_valid_counts: https://github.com/dmlc/tvm/blob/37630b848cfe51ae1ec3d42c4c0b01489a029544/topi/python/topi/cuda/nms.py#L83

It’s not fully optimized like the state-of-the-art papers or the CUB library, but it improves a lot on my previous sequential implementation.
