Explore Optimizations for Concat


#1

There are several observations by @antinucleon @vinx13 @kevinthesun suggesting that the current way of handling concatenation might not be optimal. Specifically, we currently try to fuse as much as possible and still use the same data-parallel code generators to generate concat. This results in if_then_else or switch expressions that are not necessarily the fastest, and some of these can even become bottlenecks.

We can at least come up with a few alternatives:

  • Mark concat as opaque and directly generate code that copies into the target region
    • As an experiment, skip concat entirely (treat it as a no-op) to see how much difference it makes
  • Handle concat specially, using Buffer bind semantics to generate a number of kernels that copy directly into the target region.

This thread is for discussion, as well as for any experimental results people can provide on how expensive concat is and what we can gain from these alternatives.
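
To make the difference between the current fused approach and the first alternative concrete, here is a minimal pure-Python sketch (not TVM code) of the two shapes of generated code for a 1-D concat. The function names are hypothetical. The first version runs one loop over the output and walks a cascade of range checks per element; the second runs one straight copy loop per input.

```python
from typing import List, Sequence


def concat_select(inputs: Sequence[Sequence[float]]) -> List[float]:
    """Fused, data-parallel shape: a single loop over the output in which
    every element walks a cascade of range checks (the if_then_else chain)."""
    offsets = [0]
    for a in inputs:
        offsets.append(offsets[-1] + len(a))
    out = [0.0] * offsets[-1]
    for i in range(offsets[-1]):
        for k in range(len(inputs)):
            if i < offsets[k + 1]:  # first range whose end exceeds i
                out[i] = inputs[k][i - offsets[k]]
                break
    return out


def concat_copy(inputs: Sequence[Sequence[float]]) -> List[float]:
    """Opaque shape: one branch-free copy loop per input, each writing
    straight into its own region of the output."""
    out = [0.0] * sum(len(a) for a in inputs)
    offset = 0
    for a in inputs:
        for j in range(len(a)):
            out[offset + j] = a[j]
        offset += len(a)
    return out
```

In the first version an element of input k pays up to k + 1 comparisons, so the per-element cost grows with the number of inputs; the second does no comparisons at all, which is also why it maps naturally onto one small kernel per input.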


#2

Thanks for the post. I also experienced similar problems with concat.

Is it possible to introduce a TensorView kind of abstraction, so that we can pass on a subset of a Tensor instead of allocating new space for the concat operator? I am also thinking of other memory operators like reshape, expand_dims, and squeeze, which do not change the memory contents but only change how the subsequent operator looks at/reads the data.
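
A minimal sketch of what such a view could look like, in plain Python rather than TVM (TensorView and concat_views are hypothetical names, and a flat list stands in for a device buffer): the concat output is carved into one window per input, so each producer writes its result in place and concat itself copies nothing.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TensorView:
    """Hypothetical view: the window [offset, offset + size) of a shared
    flat buffer. Creating a view copies no data."""
    buf: List[float]
    offset: int
    size: int

    def __getitem__(self, i: int) -> float:
        if not 0 <= i < self.size:
            raise IndexError(i)
        return self.buf[self.offset + i]

    def __setitem__(self, i: int, v: float) -> None:
        if not 0 <= i < self.size:
            raise IndexError(i)
        self.buf[self.offset + i] = v


def concat_views(out: List[float], sizes: List[int]) -> List[TensorView]:
    """Split the concat output into consecutive views, one per input."""
    views, offset = [], 0
    for n in sizes:
        views.append(TensorView(out, offset, n))
        offset += n
    return views
```

The same idea would cover reshape, expand_dims, and squeeze, where the view would change the shape/stride metadata instead of the offset.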


#3

Thanks for raising this. The current solution for concat is not ideal due to its recursive nature, and it may result in a stack overflow if the number of inputs is large. I saw repeating stack trace patterns like the following, roughly once per input:

#166 0x000000000c8f077e in std::function<HalideIR::Internal::Stmt (HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::operator()(HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*) const (this=0x7ffff211a420, __args#0=0x7fff3700f020, __args#1=..., __args#2=0x7fff745ea540) at ../libgcc/include/c++/7.3.0/bits/std_function.h:706 
#167 0x000000000c8ec4f7 in tvm::IRFunctor<HalideIR::Internal::Stmt (tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<HalideIR::Internal::LetStmt>(std::function<HalideIR::Internal::Stmt (HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)#1}::operator()(tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*) const (this=0x7ffff211a420, n=..., args#0=..., args#1=0x7fff745ea540) at tvm/tvm/3rdparty/HalideIR/src/tvm/node/ir_functor.h:108 
#168 0x000000000c8f9ef3 in std::_Function_handler<HalideIR::Internal::Stmt (tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*), tvm::IRFunctor<HalideIR::Internal::Stmt (tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::set_dispatch<HalideIR::Internal::LetStmt>(std::function<HalideIR::Internal::Stmt (HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>)::{lambda(tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)#1}>::_M_invoke(std::_Any_data const&, tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*&&) (__functor=..., __args#0=..., __args#1=..., __args#2=@0x7fff74502f18: 0x7fff745ea540) at ../libgcc/include/c++/7.3.0/bits/std_function.h:302 
#169 0x000000000c636b74 in std::function<HalideIR::Internal::Stmt (tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*) const (this=0x7ffff2031ba0, __args#0=..., __args#1=..., __args#2=0x7fff745ea540) at ../libgcc/include/c++/7.3.0/bits/std_function.h:706 
#170 0x000000000c63661d in tvm::IRFunctor<HalideIR::Internal::Stmt (tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::operator()(tvm::NodeRef const&, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*) const (this=0x2a97f0b0 <tvm::ir::IRMutator::vtable_stmt()::inst>, n=..., args#0=..., args#1=0x7fff745ea540) at tvm/tvm/3rdparty/HalideIR/src/tvm/node/ir_functor.h:76 
#171 0x000000000c634745 in tvm::ir::IRMutator::Mutate (this=0x7fff745ea540, stmt=...) at tvm/tvm/include/tvm/ir_mutator.h:44
#172 0x000000000ca52fa8 in tvm::ir::IRUseDefAnalysis::Mutate_ (this=0x7fff745ea540, op=0x7fff3700f050, s=...) at tvm/tvm/src/pass/split_host_device.cc:53
#173 0x0000000012e296ee in tvm::ir::<lambda(const HalideIR::Internal::LetStmt*, const HalideIR::Internal::Stmt&, tvm::ir::IRMutator*)>::operator()(const HalideIR::Internal::LetStmt *, const HalideIR::Internal::Stmt &, tvm::ir::IRMutator *) const (__closure=0x7ffff211a420, op=0x7fff3700f050, s=..., m=0x7fff745ea540) at tvm/tvm/src/pass/ir_mutator.cc:310
#174 0x0000000012e2eb47 in std::_Function_handler<HalideIR::Internal::Stmt(const HalideIR::Internal::LetStmt*, const HalideIR::Internal::Stmt&, tvm::ir::IRMutator*), tvm::ir::<lambda(const HalideIR::Internal::LetStmt*, const HalideIR::Internal::Stmt&, tvm::ir::IRMutator*)> >::_M_invoke(const std::_Any_data &, const HalideIR::Internal::LetStmt *&&, const HalideIR::Internal::Stmt &, tvm::ir::IRMutator *&&) (__functor=..., __args#0=@0x7fff74503158: 0x7fff3700f050, __args#1=..., __args#2=@0x7fff74503148: 0x7fff745ea540) at ../libgcc/include/c++/7.3.0/bits/std_function.h:302 
#175 0x000000000c8f077e in std::function<HalideIR::Internal::Stmt (HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*)>::operator()(HalideIR::Internal::LetStmt const*, HalideIR::Internal::Stmt const&, tvm::ir::IRMutator*) const (this=0x7ffff211a420, __args#0=0x7fff3700f050, __args#1=..., __args#2=0x7fff745ea540) at ../libgcc/include/c++/7.3.0/bits/std_function.h:706
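
Why a recursive lowering hurts: if concat is expressed as nested if_then_else, one level per input, the expression depth grows linearly with the number of inputs, and any recursive IR visitor (like the IRMutator frames in the trace above) needs a chain of stack frames per level. The Python sketch below uses a hypothetical tuple-based IR, not TVM's, to show the nesting.

```python
from typing import Sequence, Tuple, Union

# Hypothetical mini IR: a leaf is an input name; an internal node is
# ("select", condition_text, then_expr, else_expr).
Expr = Union[str, Tuple]


def build_nested_select(names: Sequence[str], offsets: Sequence[int], k: int = 0) -> Expr:
    """Build if_then_else(i < offsets[k+1], input_k, <rest>) recursively,
    the way a recursive concat lowering would: one nesting level per input."""
    if k == len(names) - 1:
        return names[k]
    cond = f"i < {offsets[k + 1]}"
    return ("select", cond, names[k], build_nested_select(names, offsets, k + 1))


def depth(e: Expr) -> int:
    """Measure nesting depth iteratively (the else-branch chain)."""
    d = 0
    while isinstance(e, tuple):
        d += 1
        e = e[3]
    return d
```

With n inputs the tree is n - 1 levels deep, so a visitor that recurses through it uses stack proportional to n; with thousands of inputs, the C++ equivalent overflows the stack.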

Marking concat as opaque and directly generating code that copies into the target region, as proposed above, seems a good candidate solution to me.


#4

Agree that we need some benchmarking to decide the best solution.


#5

It would be great if we could get some volunteers to look into these possibilities :slight_smile:


#6

The concatenation op makes GluonCV SSD hang for a long time (more than 5 minutes) before returning results on a Mali GPU. After I made the concatenation ops opaque, the program finished in around 400 ms.


#7

I have some new thoughts on how we could attack concat. Ideally, we want to generate several for loops (one per element of the input tuple) instead of one for loop with selection. Here is how we could achieve this through code transformation.

Specifically, we could introduce an intrinsic tvm_axis_switch (the name is open for discussion), and write loops like

for (i = 0; i < 100; ++i) {
   B[i] = tvm_axis_switch(i, 0, 20, 40, 100, A0[i], A1[i-20], A2[i-40])
} 

The semantics are straightforward: we are concatenating A0, A1, and A2, and tvm_axis_switch indicates that we switch on i (possibly a loop variable), picking the branch whose range contains i: [0, 20) for A0, [20, 40) for A1, and [40, 100) for A2.
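
A reference model of these semantics may help pin them down. This is a plain-Python sketch; tvm_axis_switch is only a proposal here, so the function name and argument layout below are assumptions. Branches are passed as callables so that only the selected branch reads memory, which matters because A1[i-20] would be out of bounds for i < 20.

```python
from typing import Callable, Sequence


def axis_switch(i: int, bounds: Sequence[int],
                branches: Sequence[Callable[[int], float]]) -> float:
    """Hypothetical model of tvm_axis_switch: bounds = [b0, b1, ..., bn]
    splits [b0, bn) into n ranges, and branch k is evaluated (lazily)
    when bounds[k] <= i < bounds[k + 1]."""
    if len(bounds) != len(branches) + 1:
        raise ValueError("need exactly one more bound than branches")
    for k, f in enumerate(branches):
        if bounds[k] <= i < bounds[k + 1]:
            return f(i)
    raise IndexError(f"{i} is outside [{bounds[0]}, {bounds[-1]})")
```

Evaluating it for every i in [0, 100) with the bounds and branches from the loop above reproduces A0 + A1 + A2.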

Then we write a pass, SplitAxisSwitch, which pattern-matches this loop and splits it into several loops. To keep things simple, we could require that i is indeed a loop variable and that the switch range matches the range of the serial loop. If pattern detection fails, we fall back to if_then_else.
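
The core of such a pass can be sketched on a toy loop representation (everything here is hypothetical, not TVM IR): check that the switch ranges tile the serial loop's range exactly, and if so emit one branch-free loop per range; otherwise return None so the caller keeps the if_then_else form.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SwitchLoop:
    """for (i = 0; i < extent; ++i) B[i] = axis_switch(i, bounds, srcs[k][i - bounds[k]])"""
    extent: int
    bounds: List[int]
    srcs: List[str]


@dataclass
class CopyLoop:
    """for (j = 0; j < end - begin; ++j) B[begin + j] = src[j]"""
    begin: int
    end: int
    src: str


def split_axis_switch(loop: SwitchLoop) -> Optional[List[CopyLoop]]:
    """Hypothetical SplitAxisSwitch: split only when the ranges are sorted
    and exactly cover [0, extent); otherwise signal a fallback with None."""
    b = loop.bounds
    if len(b) != len(loop.srcs) + 1 or b[0] != 0 or b[-1] != loop.extent:
        return None
    if any(b[k] > b[k + 1] for k in range(len(b) - 1)):
        return None
    return [CopyLoop(b[k], b[k + 1], s) for k, s in enumerate(loop.srcs)]


def to_c(loops: List[CopyLoop]) -> str:
    """Render the split loops as C-like text for inspection."""
    return "\n".join(
        f"for (j = 0; j < {l.end - l.begin}; ++j) B[{l.begin} + j] = {l.src}[j];"
        for l in loops)
```

For the example loop this yields three loops of extent 20, 20, and 60, none of which contains a branch.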

This means we need a special OpPattern for concat (InputFusableOutputElemwiseFusable), which allows fusing injective ops on the input side and elementwise ops on the output side, plus a special schedule that leaves the concat axis alone (so it stays a simple loop and enables the follow-up optimization).
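
The fusion rule for the proposed pattern can be sketched as two predicates. The enum below is illustrative only (its names and ordering are assumptions, not TVM's actual OpPattern values); CONCAT_FUSABLE stands in for InputFusableOutputElemwiseFusable.

```python
from enum import IntEnum


class Pattern(IntEnum):
    """Illustrative op patterns, ordered from most to least fusable.
    A sketch, not TVM's definition."""
    ELEMWISE = 0
    BROADCAST = 1
    INJECTIVE = 2
    CONCAT_FUSABLE = 3  # stands in for InputFusableOutputElemwiseFusable
    OPAQUE = 4


def can_fuse_into_concat(producer: Pattern) -> bool:
    """Input side: concat may absorb injective (or simpler) producers."""
    return producer <= Pattern.INJECTIVE


def can_fuse_after_concat(consumer: Pattern) -> bool:
    """Output side: only elementwise consumers, whose index mapping is the
    identity, keep the concat axis a plain loop that can still be split."""
    return consumer == Pattern.ELEMWISE
```

Restricting the output side to elementwise ops is what keeps i a bare loop variable in the fused body, which is the precondition the split pass checks.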


#8

cc @hlu1 @ajtulloch, who might also be interested. I would also like to hear everyone's thoughts, and whether anyone is interested in taking a stab at this.


#9

We have found a simple workaround for concatenating 2D tensors (currently our most common use case). If we unroll the last axis, LLVM is smart enough to generate vectorized code, and the performance is even better than the C code in Caffe2. For benchmark numbers, see https://gist.github.com/ajtulloch/d3b47517721c71c09375fd76f387e718 from @ajtulloch.
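
The structure of the unrolled code can be sketched with a small Python generator that prints C-like source. The names A0, A1, B, and N are placeholders, and this is not TVM's code generator. The row loop stays a loop, while the last axis becomes straight-line copies with constant column offsets, which is the form LLVM finds easy to vectorize.

```python
from typing import List


def emit_concat2d_unrolled(widths: List[int], rows: str = "N") -> str:
    """Emit C-like source concatenating 2-D inputs A0, A1, ... (row-major,
    input k has widths[k] columns) along the last axis, fully unrolled."""
    total = sum(widths)
    lines = [f"for (int r = 0; r < {rows}; ++r) {{"]
    col = 0
    for k, w in enumerate(widths):
        for c in range(w):
            lines.append(f"  B[r * {total} + {col}] = A{k}[r * {w} + {c}];")
            col += 1
    lines.append("}")
    return "\n".join(lines)
```

With ten 64-wide inputs, the loop body is already 640 assignment statements, which hints at why compile time grows when many wide concats are unrolled.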


#10

It would be nice if you have the bandwidth to follow up on this :slight_smile:


#11

Sounds good. Will do.


#12

One disadvantage of unrolling concat is that it can increase compilation time significantly when there are many concats with many inputs. On one model we tested, compilation took about 10 min instead of several seconds without unrolling. Basically, we're trading compilation time for run time. @tqchen's tvm_axis_switch approach might be a better alternative overall.


#13

Yes, I have also run into this problem. One workaround may be to limit max_unroll to 16, as in our conv2d.py for ARM CPU.
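
Capped unrolling can be sketched by emitting straight-line copies only for narrow inputs and keeping a loop for wide ones, which bounds the code emitted per input. The helper below is hypothetical and emits C-like text; max_unroll = 16 is the limit suggested above.

```python
from typing import List


def emit_copy(k: int, w: int, col: int, total: int, max_unroll: int = 16) -> List[str]:
    """Emit the copy of input A{k} (w columns) into columns [col, col + w)
    of B (total columns). Unroll only when w <= max_unroll."""
    if w <= max_unroll:
        return [f"  B[r * {total} + {col + c}] = A{k}[r * {w} + {c}];" for c in range(w)]
    # Wide input: keep a single inner loop instead of w statements.
    return [f"  for (int c = 0; c < {w}; ++c) B[r * {total} + {col} + c] = A{k}[r * {w} + c];"]
```

The trade-off is that wide inputs lose the constant offsets that made vectorization easy, although LLVM can still vectorize a simple inner copy loop.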


#14

According to my tests, simply marking concat as opaque also works well.


#15

@hlu1 @tqchen Do we have some updates on this?


#16

We decided to go with a memcpy-based version (using tvm.extern) internally, because it's simple and works for our use cases (single-threaded inference). I'm happy to upstream the implementation if it's useful for the community.
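
A memcpy-based concat mostly comes down to offset arithmetic on row-major buffers. The following is a hedged pure-Python sketch of that arithmetic, not the implementation mentioned above (which isn't shown here); a list slice assignment plays the role of memcpy, and concat_memcpy is a hypothetical name.

```python
from math import prod
from typing import List, Sequence, Tuple


def concat_memcpy(inputs: Sequence[Tuple[Sequence[float], Tuple[int, ...]]],
                  axis: int) -> Tuple[List[float], Tuple[int, ...]]:
    """Concatenate row-major flat buffers along `axis` using only contiguous
    block copies. inputs: (flat_data, shape) pairs whose shapes agree on
    every dimension except `axis`."""
    shapes = [s for _, s in inputs]
    outer = prod(shapes[0][:axis])            # number of independent blocks
    inner = [prod(s[axis:]) for s in shapes]  # contiguous chunk per block
    total_inner = sum(inner)
    out = [0.0] * (outer * total_inner)
    for o in range(outer):
        dst = o * total_inner
        for (data, _), n in zip(inputs, inner):
            out[dst:dst + n] = data[o * n:(o + 1) * n]  # one "memcpy" of n elements
            dst += n
    out_shape = shapes[0][:axis] + (sum(s[axis] for s in shapes),) + shapes[0][axis + 1:]
    return out, out_shape
```

The number of copies is outer × (number of inputs), so concat along axis 0 is a single copy per input, while concat along the last axis of a 2D tensor needs one small copy per row per input.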


#17

We might want to avoid memcpy, since different device types such as GPUs have different memcpy APIs. We could start with a basic extern op made of scripted for loops that works on both GPU and CPU, then move on to supporting the axis_switch-based version.


#18

I will try to implement the axis_switch solution.