Why does the FoldConstant optimization need ScheduleOps?

Hi, I’m new to the TVM stack. While debugging the Relay build process, I found that the FoldConstant optimization enters the ScheduleOps function. Why? Does the optimization stage of the Relay build operate only at the graph level, or are both graph-level and op-level optimizations done there? Thanks very much.

FoldConstant identifies sections of the Relay graph that can be precomputed at compile time. Basically, it extracts subgraphs from the original graph that can be precomputed. Each of these subgraphs is then compiled using Relay (this is where it goes through scheduling) and executed to get the new constant. Therefore, the whole compilation and runtime toolchain is invoked for each such subgraph. Hope this helps.
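For illustration, here is a minimal sketch using the public relay.transform API (the expression and variable names are my own) that shows a constant subexpression being folded into a single constant:

```python
import tvm
from tvm import relay

# x + (2.0 * 3.0): the multiply has only constant inputs,
# so FoldConstant can precompute it.
x = relay.var("x", shape=(1,), dtype="float32")
c = relay.multiply(relay.const(2.0, "float32"), relay.const(3.0, "float32"))
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, c)))

mod = relay.transform.FoldConstant()(mod)
print(mod)  # the multiply is replaced by the constant 6.0
```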


Thanks for the reply. The BuildRelay function includes several steps, like Optimize and GetLoweredFunc, and then calls tvm.build. The subgraphs extracted by the FoldConstant optimization are fully compiled within the Optimize step, while the other subgraphs (those not extracted by FoldConstant) go through all the steps of the Relay build. What confuses me is that Optimize seems to contain not only graph optimization but also ScheduleOps and codegen. Could the subgraphs extracted by FoldConstant instead be executed outside of Optimize (for example, in tvm.build)? Not sure I described it clearly. Thanks a lot.

I stumbled over the same confusion when tracing the FoldConstant optimization pass. Some of my debug prints show the process below.

I don’t know why (anyone, please explain if you know), but Relay kicks off the ‘whole compilation process’ through the Interpreter (interpreter.cc).

  • The ‘EtaExpand’ pass is explicitly listed in FoldConstant’s own Sequential pass.
  • The ‘FuseOps’ pass is triggered by the ConstEvaluate(expr) function in the ConstantFolder class (a Python sketch of this step follows below).
  • The ‘InferType’ pass is triggered as a required dependency of ‘FuseOps’.
  • Then CompileEngineImpl::JIT(key) is called to do the JIT compilation, which kicks off the backend lower_call() to select among the TOPI schedule implementations.
  • The process goes on with lower-level TIR passes, such as tir.ThreadSync, tir.SplitHostDevice, etc.
  • Finally it goes through the CodeGen process as well, finishing the whole pipeline.

According to the Relay documentation, the ‘Interpreter’ is mainly for ‘debug’ purposes, i.e. a quick-and-dirty reference implementation.
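As a rough Python re-enactment of that sequence (the real logic is C++ inside the ConstantFolder; the pass names below are real, while the helper function and the zero-argument main function are my own illustration):

```python
import tvm
from tvm import relay

def const_evaluate_sketch(mod):
    """Approximate what ConstEvaluate does for one extracted constant subgraph."""
    # FuseOps (at opt level 0) and InferType run first, matching the trace above.
    seq = tvm.transform.Sequential([
        relay.transform.FuseOps(fuse_opt_level=0),
        relay.transform.InferType(),
    ])
    mod = seq(mod)
    # The debug interpreter then JIT-compiles and runs the subgraph, which is
    # why the full lowering and codegen pipeline shows up in the trace.
    return relay.create_executor("debug", mod=mod, target="llvm").evaluate()()
```

Calling this on a module whose main function takes no parameters returns the computed value, which mirrors how FoldConstant obtains the constant it splices back into the graph.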

transform.cc SequentialNode::operator(), pass name:FoldConstant
transform.cc SequentialNode::operator(), pass name:EtaExpand
transform.cc SequentialNode::operator(), pass name:FuseOps
transform.cc SequentialNode::operator(), resolved dependency pass name:InferType
transform.cc SequentialNode::operator(), pass name:InferType
interpreter.cc VisitExpr_(CallNode*): Invoke() -> calls JIT(key)
CompileEngineImpl::JIT(key)
Inside compile_engine.cc VisitExpr_(CallNode)
 Calling into Python relay.backend.lower_call()
tvm/python/tvm/relay/backend/compile_engine.py, select_implementation(), op.name= multiply
  valid implementation  0 :  injective.cpu plevel= 10
  selected best_plevel_implementation:  injective.cpu
Use implementation injective.cpu for op multiply
tvm/python/tvm/relay/backend/_backend.py: lower function:  fused_multiply
lower phase 0
lower phase 1
lower phase 2
lower phase 3
produce T_multiply {
  T_multiply[ramp(0, 1, 16)] = (x16(placeholder[0])*placeholder[ramp(0, 1, 16)])
}

transform.cc SequentialNode::operator(), pass name:_transform
transform.cc SequentialNode::operator(), pass name:tir.ThreadSync
transform.cc SequentialNode::operator(), pass name:tir.ThreadSync
transform.cc SequentialNode::operator(), pass name:tir.InferFragment
transform.cc SequentialNode::operator(), pass name:tir.LowerThreadAllreduce
transform.cc SequentialNode::operator(), pass name:tir.BindDeviceType
transform.cc SequentialNode::operator(), pass name:tir.SplitHostDevice
transform.cc SequentialNode::operator(), pass name:_transform
transform.cc SequentialNode::operator(), pass name:tir.LowerWarpMemory
transform.cc SequentialNode::operator(), pass name:tir.LowerDeviceStorageAccessInfo
transform.cc SequentialNode::operator(), pass name:tir.LowerIntrin
transform.cc SequentialNode::operator(), pass name:_transform
transform.cc SequentialNode::operator(), pass name:_transform
transform.cc SequentialNode::operator(), pass name:tir.LowerTVMBuiltin
transform.cc SequentialNode::operator(), pass name:tir.LowerDeviceStorageAccessInfo
transform.cc SequentialNode::operator(), pass name:tir.LowerIntrin
transform.cc SequentialNode::operator(), pass name:tir.CombineContextCall
runtime::Module Build(): target.build.llvm

This is because TVM doesn’t contain any implementation of the operators other than the TIR implementations. As we know, constant folding needs to actually run the operator to calculate the value to fold. So the only choice is to use the TIR schedule to compile a program for the host (e.g. x86), then execute it to compute the constant.
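If you want to avoid this compile-and-run cost during Optimize, one option (assuming you are driving compilation through relay.build yourself) is to disable the pass via the PassContext:

```python
import tvm
from tvm import relay

# Skip constant folding entirely, so no subgraph gets compiled
# and executed during the Optimize step.
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldConstant"]):
    lib = relay.build(mod, target="llvm")  # `mod` is your IRModule
```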