Add TensorFlow custom op and run TVM in TensorFlow

TVM now supports most TensorFlow ops, which can be loaded with tvm.relay. Users can convert TensorFlow ops and run them with the TVM runtime, but sometimes we want to run a TensorFlow session that contains some TVM-optimized ops.

It is possible to wrap a TVM op as a TensorFlow custom op so that it can be part of a TensorFlow graph. The user interface could be similar to TF-TensorRT, which provides a TensorFlow custom op and a graph editor to run an optimized TensorFlow model without extra effort.


The TVM runtime is designed on top of DLPack, so as long as TensorFlow supports DLPack, we should be able to convert tensors back and forth. PyTorch, Chainer and MXNet are already part of the DLPack ecosystem, but I am not sure if TF has a timeline for this…
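For frameworks that already speak DLPack, the round trip is cheap. Here is a minimal sketch with PyTorch and TVM (assuming the tvm.nd.from_dlpack and NDArray.to_dlpack helpers; exact names may vary by version):

import torch
import torch.utils.dlpack
import tvm

x = torch.arange(4, dtype=torch.float32)
# torch -> tvm: both sides share the same underlying buffer (zero copy).
tvm_arr = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
# tvm -> torch: hand the buffer back through a DLPack capsule.
y = torch.utils.dlpack.from_dlpack(tvm_arr.to_dlpack())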

CC @yzh119 in case the DGL team shares the same interest

Thanks for your reply @junrushao. We have a new project that wraps the TVM Runtime API in a TensorFlow custom op without requiring native DLPack support: https://github.com/tobegit3hub/tftvm . It is similar to TF-TRT, and end users can easily run TVM ops within a TensorFlow graph and session.


The implementation looks great. Would you like to work out a more detailed API design proposal and contribute it back?

In particular, I can see a few places that could be improved:

  • Look into TF’s memory alignment and related constraints, and do zero copy by creating a DLTensor instead of a memcpy
  • Provide better ways to specify the module and the functions inside it
  • Define a good calling convention

This is great! Looking forward to the RFC and discussion.

I think this is a proof of concept that deserves more attention: the use cases are ubiquitous, and we can imagine that people would be quite interested.

  • to-tvm: TVM already has importers from many frameworks (TF, ONNX, MXNet, etc.).
  • from-tvm: methods to embed TVM into other systems (TF, PyTorch) are still lacking.

So I would say “from-tvm” is probably worth trying and may have a great impact.

Thanks for all your interest. We have updated the code to support a better way to load TVM libraries, based on @tqchen's comment. Users can load a TVM op just like a TensorFlow op without extra effort. We also have graph editor scripts to replace original TensorFlow ops with TVM ops.

import tensorflow as tf
from tvm_runtime import tvm_runtime

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  # Run the "addone" function from the compiled TVM library as a TF op.
  b = tvm_runtime(a, lib_path="tvm_addone_dll.so", function_name="addone")
  print(sess.run(b))

It would be great if someone could help us draft the RFC together and push this to the TVM community. The code is open source, and any feedback is welcome.


It would be great if you could summarize your key API decisions and high-level implementation details as an RFC.

Some high-level thoughts:

  • Put the code into tvm/src/contrib/tf_runtime
  • The API could use some additional wrapping to make it more user friendly:
import tensorflow as tf
from tvm.contrib import tf_runtime

# Load the compiled TVM library once, then look up functions by name.
mod = tf_runtime.Module(lib_path="tvm_add_one_dll.so")
addone = mod["addone"]

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  b = addone(a)
  print(sess.run(b))

In terms of implementation, it would be great to have a class that just takes in a PackedFunc and implements the bridge, and then sub-classes that take in a Module. This way we could reuse the same infrastructure for things like graph_runtime, which can then be embedded into a PackedFunc.
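As a rough illustration, such a wrapper could be layered on top of the tvm_runtime op from the earlier example; a hypothetical sketch (the class internals are illustrative, not a finalized design):

from tvm_runtime import tvm_runtime

class Module:
  """Loads a compiled TVM library and exposes its functions by name."""
  def __init__(self, lib_path):
    self._lib_path = lib_path

  def __getitem__(self, function_name):
    # Return a callable that runs the named TVM function as a TF op.
    def func(*tensors):
      return tvm_runtime(*tensors, lib_path=self._lib_path,
                         function_name=function_name)
    return func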

It would also be great if you could spend a bit of time on zero copy.

@tobegit3hub are you currently working on this? Specifically, are you working on TF runtime support in tvm/src/contrib/tf_runtime?

Thanks @tqchen and @jonso. We are working on drafting the RFC, and it will cover the tf_runtime API mentioned above.

It may take some time because we are new to TVM and the community. It would be great if committers could help, and we hope to discuss the details of the core and API design with you soon.

It would be great if we could get knowledge of function signatures, such as shapes and types, from the module. Are there ways to extract such “attributes” of exported modules?
For example, if we want to implement zero copy from TF tensors, we may have to ensure that no in-place computation happens in the current packed function.

The generated TVM code already ensures that no in-place computation will happen on the arguments being passed in.

The better-formatted RFC is in [RFC] Add Tensorflow custom op to embed TVM runtime in TensorFlow graph and session; feel free to comment and discuss in that post.


Hi @tobegit3hub @tqchen,

We figured out a way to avoid the copy in TensorFlow. Details can be found at https://github.com/VoVAllen/tf-dlpack/issues/3
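For reference, newer TensorFlow releases also expose experimental DLPack hooks for eager tensors; a minimal sketch (assuming tf.experimental.dlpack, available from TF 2.2, and tvm.nd.from_dlpack):

import tensorflow as tf
import tvm

t = tf.constant([10.1, 20.0, 11.2, -30.3])
capsule = tf.experimental.dlpack.to_dlpack(t)  # eager Tensor -> DLPack capsule
tvm_arr = tvm.nd.from_dlpack(capsule)          # zero-copy view as a TVM NDArray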


Thanks @VoVAllen. It would be great if we could leverage the implementation from tf-dlpack to convert a Tensor to DLPack. Can we use a C++ library to convert the Tensor to DLPack instead of using another TensorFlow custom op?

@tobegit3hub It’s possible but would be hard. TensorFlow 2.0 does a lot of conversion to preserve compatibility between symbolic mode and eager mode, and those APIs don’t seem stable. I’m not sure how to get the correct C++ object on the C++ side. The good news is that TensorFlow is migrating from SWIG to pybind11, which should make the C++ API much simpler.

What do you mean by a C++ library instead of a custom op? What’s the difference between them, and what would you expect the library to look like?

Hi @VoVAllen, we have read the source code of tf-dlpack, and it provides custom TensorFlow operators to convert to/from TensorFlow Tensors.

However, it does not seem to handle the case where the memory is not aligned. We have another implementation that copies the data to a newly aligned pointer in https://github.com/tobegit3hub/tftvm/pull/4/files . It may be inefficient, and the memory is always reported as not aligned. Do you have any solution for this?

  // Custom allocator hook that returns the existing DLPack buffer to
  // TensorFlow instead of allocating fresh memory (zero copy).
  void *AllocateRaw(size_t alignment, size_t num_bytes) {
    // Sanity check: TF should request exactly the bytes the DLTensor holds.
    if (num_elements_ * (dlm_tensor_->dl_tensor.dtype.bits) / 8 != num_bytes) {
      std::cout << "Invalid allocation bytes" << std::endl;
    }
    // Warn when the pointer does not satisfy the requested alignment.
    auto iptr = reinterpret_cast<std::uintptr_t>(data_);
    if (iptr % alignment) {
      std::cout << "Memory not aligned" << std::endl;
    }
    return data_;
  }

Hi @tobegit3hub,

So far we haven’t found a solution other than copying to ensure alignment. However, I don’t think your case will hit this issue, since the input and output tensors in your case are both allocated by TensorFlow. And the alignment requirement on the input tensor comes from the TVM function, not the TF kernel, right?

Unless the tensor is allocated by a TVM function (and even then you can ensure alignment when allocating in TVM), TF-allocated tensors are always aligned. In my case, the tensor is sometimes sliced from another tensor without a copy, so its address won’t be aligned.

Thanks @VoVAllen, and you are right. The default TensorFlow allocator seems to create Tensors with alignment compatible with DLPack. We will check the alignment and copy the data only if it is not aligned.
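A hypothetical NumPy sketch of that check-then-copy fallback; the over-allocate-and-offset trick guarantees the requested alignment (the helper name and the 64-byte default are illustrative):

import numpy as np

def ensure_aligned(arr, alignment=64):
  # Zero copy when the buffer already satisfies the alignment requirement.
  if arr.ctypes.data % alignment == 0:
    return arr
  # Otherwise over-allocate a byte buffer, start at an aligned offset,
  # and copy the data exactly once.
  buf = np.empty(arr.nbytes + alignment, dtype=np.uint8)
  offset = (-buf.ctypes.data) % alignment
  out = buf[offset:offset + arr.nbytes].view(arr.dtype).reshape(arr.shape)
  np.copyto(out, arr)
  return out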