I have a question regarding batch normalization in NNVM.
By default (with opt_level=0), it seems that NNVM unpacks batch normalization into several unit operators such as add_scalar, sqrt, rdiv_scalar, and elemwise_mul.
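To make the decomposition concrete, here is a minimal pure-Python sketch for a single channel. It is my own reading of how the unpacked operators fit together, not code taken from NNVM: the function names are hypothetical, and the steps after elemwise_mul (the shift and the final multiply-add) are an assumption about how the result is applied to the data.

```python
import math

def batch_norm_single_op(x, gamma, beta, mean, var, eps=1e-5):
    """Inference-time batch normalization for one channel, as one expression."""
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]

def batch_norm_unpacked(x, gamma, beta, mean, var, eps=1e-5):
    """Same result, split into steps named after the unit operators."""
    t = var + eps                # add_scalar
    t = math.sqrt(t)             # sqrt
    t = 1.0 / t                  # rdiv_scalar
    scale = t * gamma            # elemwise_mul
    # Assumption: the remaining steps fold mean and beta into a shift,
    # then apply scale and shift to every element of the channel.
    shift = beta - mean * scale
    return [v * scale + shift for v in x]

x = [0.5, -1.0, 2.0]
single = batch_norm_single_op(x, gamma=1.5, beta=0.1, mean=0.2, var=0.8)
unpacked = batch_norm_unpacked(x, gamma=1.5, beta=0.1, mean=0.2, var=0.8)
```

Both functions should agree up to floating-point rounding, which is why I would expect a single fused operator to be numerically equivalent to the unpacked graph.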
If I want to keep batch normalization as a single operator and have it fused with its neighbors, e.g. @tvm_op(…, …, func_name="fuse_conv2d_batch_norm_relu_…", …), what steps do I need to take?
So far I have noticed the following:
- Computation of batch normalization is implemented in Python (./tvm/topi/python/topi/nn/batch_norm.py) as well as C++ (./tvm/topi/include/topi/nn/batch_norm.h).
- The computation and schedule are not registered to NNVM (see ./python/nnvm/top/nn.py), whereas they are registered to TVM (see ./tvm/topi/src/topi.cc).
- The operator pattern for batch normalization is registered as BROADCAST, as seen in the tag scope of batch_norm_inference() in batch_norm.py.
I tried registering the batch normalization op with NNVM and running the code, but it eventually fails in GraphFuseCompile(), specifically when fcompute (essentially the lambda function from the batch normalization code in batch_norm.py) is called from compute() in ./tvm/python/tvm/api.py. Since fcompute is a lambda function, I'm not sure how to debug or trace the code efficiently.
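One thing I am considering is replacing the lambda with a named function that takes the same explicit index arguments, so that it shows up by name in tracebacks and I can put a print or breakpoint inside it. Below is a minimal sketch of that idea. It uses plain Python lists as stand-ins for tensors, and make_bn_compute, bn_compute, and the two-index (n, c) signature are all hypothetical; the real fcompute in batch_norm.py may take different indices.

```python
def make_bn_compute(data, scale, shift, log=print):
    """Build a named stand-in for the fcompute lambda (hypothetical helper)."""
    def bn_compute(n, c):
        # A named function appears as "bn_compute" in tracebacks, and this
        # line is a natural place for a breakpoint or pdb.set_trace().
        log("bn_compute called with n=%r, c=%r" % (n, c))
        return data[n][c] * scale[c] + shift[c]
    return bn_compute

# Stand-in inputs: a 2x2 "tensor" plus per-channel scale and shift.
data = [[1.0, 2.0], [3.0, 4.0]]
scale = [0.5, 2.0]
shift = [0.1, -1.0]

calls = []  # collect log messages instead of printing them
fcompute = make_bn_compute(data, scale, shift, log=calls.append)
value = fcompute(1, 0)
```

I kept the explicit argument list rather than wrapping the lambda in a generic *args decorator, because I have not checked how compute() inspects the signature of fcompute, and a wrapper might change what it sees.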
Do you have any comments or suggestions? Thanks!
Won