[RFC] Add Tensorflow custom op to embed TVM runtime in TensorFlow graph and session

Problem:

TensorFlow is one of the most popular machine learning libraries, and most developers are used to training and serving models with TensorFlow/TensorFlow Serving. TVM is a flexible compiler that runs computation efficiently on different devices. Although TensorFlow has implemented some efficient GPU operators, developers can benefit from TVM to get more than 10x speedups and FPGA support. However, TensorFlow and TVM are two different code stacks with different runtime APIs.

There are two ways to integrate TVM with TensorFlow. The first is tensorflow-to-tvm, which is already supported by the Relay importer. Most TensorFlow operators can be “translated” to TVM operators, which is useful if we want to run the TVM stack with a model structure defined in another framework.

The second one is tvm-to-tensorflow. This requires embedding TVM operators in a TensorFlow graph so that we can use a TensorFlow session to run both the preset operators and the TVM-optimized operators. This is really helpful if we want to use TVM to optimize part of the computation graph while developers keep using the TensorFlow Python API to describe the model and TensorFlow Serving for inference. Embedding TVM in TensorFlow requires minimal cost to apply TVM optimization to existing models and extends TensorFlow with functionality such as FPGA support.

This RFC describes how we plan to support tvm-to-tensorflow with the TensorFlow custom op API, and the details of the implementation.

Considerations:

We want to support the complete TVM stack. TVM provides an efficient C++ API, an easy-to-use Python API to define kernel schedules, and AutoTVM to search for optimal parameters. We want developers to use these existing tools to define TVM operators and embed the output files, instead of re-implementing the same logic in TensorFlow.

We want zero development effort for end users. TensorFlow provides a C++ API to define custom ops and a Python API to load them from dynamic libraries. We don’t want users to write C++ and Python code to wrap a TVM op as a TensorFlow custom op by themselves. We can implement a general C++ TVM runtime operator and Python class so that users can use TVM ops in TensorFlow without implementing any TensorFlow custom op.

We want minimal code changes for usage. Some code change is inevitable because we need to specify which op is to be replaced. We can use the TensorFlow graph editor API to replace the original TensorFlow op with the TVM op. Since TVM already supports some TensorFlow ops with the same functionality and may have better performance, we can design tools like TF-TRT to automatically rewrite a TensorFlow SavedModel into an optimized one that uses TVM ops. This can be done once we can embed TVM ops in a TensorFlow graph and replace TensorFlow ops with TVM ops.

Proposal:

The code cannot be merged into the TVM codebase yet, but the API should be similar. Users can use the TVM stack to build an op and export it as dynamic library files. Here is example code to export the TVM dynamic libraries for CPU and GPU.

import os

import tvm

# CPU
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
s = tvm.create_schedule(B.op)
fadd_dylib = tvm.build(s, [A, B], "llvm", name="addone")
base_path = "."  # output directory for the exported libraries
dylib_path = os.path.join(base_path, "test_addone_dll.so")
fadd_dylib.export_library(dylib_path)

# GPU: reuse A and B, create a new schedule and bind it to CUDA threads
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(s, [A, B], "cuda", name="addone")
dylib_path = os.path.join(base_path, "test_addone_cuda_dll.so")
fadd_dylib.export_library(dylib_path)

Then we can use the pre-built TensorFlow custom op for the TVM runtime together with a Python wrapper. This op behaves like any other TensorFlow op and can be used in a TensorFlow graph and session. There are two options for wrapping this Python API.

Option one is to use the TensorFlow custom op Python API directly and initialize the op with the library file path and function name.

import tensorflow as tf
from tvm.contrib import tf_runtime

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  b = tf_runtime(a, lib_path="tvm_addone_dll.so", function_name="addone")
  print(sess.run(b))

Option two is to extend the GraphModule in the TVM Python API, or wrap it with a new tf_runtime.Module class.

import tensorflow as tf
from tvm.contrib import graph_runtime

mod = graph_runtime.create(graph, lib, ctx)
addone = mod["addone"]

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  b = addone(a)
  print(sess.run(b))

However, we have to call the underlying TensorFlow API to load the custom op and return the tensor object. GraphModule is a Python bridge to C++ that runs TVM ops directly, rather than a standard TensorFlow op that is run by a TensorFlow session. It is fine to wrap the TensorFlow custom op with any Python class, as long as we try to match the usage of other TensorFlow ops.

The TensorFlow custom op for the TVM runtime can be implemented by combining the TVM runtime C++ API and the TensorFlow custom op C++ API. We have an implementation and examples in https://github.com/tobegit3hub/tftvm which can be moved to tvm.contrib once the API is determined. Here is the code to register the TensorFlow custom op; it requires lib_path and function_name to load the TVM dynamic libraries. Moreover, the TVM runtime API needs to know the dtype and shape of the input tensors; this information can be passed either through TensorFlow op attributes or loaded from the TVM dynamic libraries.

REGISTER_OP("TvmRuntime")
    .Attr("lib_path: string")
    .Attr("function_name: string")
    .Input("tvm_input: float")
    .Output("tvm_output: float");

Then we can implement the CPU and GPU kernels with the TVM runtime C++ API. We load the dynamic libraries with the attribute parameters when initializing the op. For each invocation, we override the Compute method to run the computation through the TVM runtime API and read/write the data as TensorFlow tensors.

void Compute(OpKernelContext* context) override {
  // device_type/device_id come from a trait that maps the TF device to TVM;
  // x, y (DLTensor*) and ndim/dtype_* are assumed to be members set up earlier
  int device_type = TvmRuntimeOpTrait<DEVICE_TYPE>::device_type;
  int device_id = TvmRuntimeOpTrait<DEVICE_TYPE>::device_id(context);
  int64_t shape[1] = {10};  // hard-coded shape in this example
  TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
                device_type, device_id, &x);
  TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
                device_type, device_id, &y);

  // Point the TVM input at the TF input tensor's data
  auto input = input_tensor.flat<float>();
  x->data = const_cast<float*>(input.data());
  const int input_size = input.size();

  // Run the TVM PackedFunc
  tvm_func(x, y);

  // Copy the TVM result into the TF output tensor
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                   &output_tensor));
  auto output_flat = output_tensor->flat<float>();
  memcpy(output_flat.data(), y->data, input_size * sizeof(float));
  // On GPU: cudaMemcpy(output_flat.data(), y->data,
  //                    input_size * sizeof(float), cudaMemcpyDeviceToDevice);
}

Notice that the TensorFlow custom op only needs to be built once and can then be used to load different TVM operators. Developers need a TensorFlow environment to build this custom op with g++ or bazel.

Finally, we can use the TensorFlow graph editor API to replace original TensorFlow ops with the TVM-optimized custom op, and build automatic TensorFlow SavedModel converters. These are less important than the custom op described above, and we can discuss them once the fundamental functionality is ready.

Related discussion is in “Add TensorFlow custom op and run tvm in TensorFlow”.


I think that this is a great idea, and I overall agree with the implementation. Implementation 2 is a little cleaner IMO.

Regarding implementation 2: I am currently working on a change to use original model runtimes like TensorFlow in an executing graph. The RFC is here. The idea is to run operators that aren’t supported by TVM in their original runtime. Personally, I feel that the runtime that runs TensorFlow directly should be called tf_runtime. Maybe the runtime that you are proposing could be called tf_tvm_op_runtime?

I don’t see too much overlap between these runtimes, although there will be duplicated work (such as linking with TensorFlow).

Thanks @jonso . You’re absolutely right.

The TensorFlow custom op has to be implemented in C++. We need to load the dynamic libraries with Module/PackedFunc in C++ and use the TensorFlow Python API to load the TensorFlow custom op. This somewhat duplicates the TVM Python API, but we cannot use the TVM Python graph_runtime and GraphModule directly.

For the TensorFlow custom op, we need to set attributes for each tensor, such as tf_runtime(a, lib_path="tvm_addone_dll.so", function_name="addone"). If we don’t want to set lib_path multiple times, we can wrap it with another class that initializes these first. The new wrapper class could be named tf_runtime.Module, but it would be quite different from the existing graph_runtime.GraphModule.

@jonso suggests a new way to embed the whole TensorFlow session in a TVM op, and that one would be named tf_runtime.

This proposal will generate a TensorFlow custom op with the TVM runtime to be used in TensorFlow scripts, so we may rename ours to tvm.contrib.tf_op, and the API could be much more user-friendly.

import tensorflow as tf
from tvm.contrib.tf_op import TvmModule

mod = TvmModule("test_addone_cuda_dll.so")
addone = mod.func("addone")

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  b = addone(a)

One thing to note though is that in order to enable TF’s remote driving capability, we cannot assume that the Python process driving the session is the same as the one running the graph.

So we may not be able to directly use the TVMModule way (which is more like an eager-mode call in the same process as the running Python one), because the dll needs to be loaded on the remote side.

Although I am not sure how frequently such remote driving capability is used at the moment, I can certainly see that remote driving could be important in certain ways for distributed training, though other approaches to distributed training may not need this capability.

The TVMRuntime custom op registration will be able to resolve this remote loading use case (as long as the compiled lib is shipped with the TF program). It would be great to make the custom op implementation a bit more layered:

TVMPackedFuncOp (base class: takes in a PackedFunc and implements Compute)
TVMDSOOp : TVMPackedFuncOp (takes in the dll path and function name, caches the dll in a singleton so each dll is loaded only once, and obtains the PackedFunc)
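
A minimal sketch of that layering might look like the following; the class and member names are assumptions, not the actual tftvm implementation, and kernel registration/error handling are left out.

#include <string>

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include "tensorflow/core/framework/op_kernel.h"

// Sketch of the proposed layering; names and details are assumptions.
class TVMPackedFuncOp : public tensorflow::OpKernel {
 public:
  explicit TVMPackedFuncOp(tensorflow::OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    // Wrap the TF input/output tensors as DLTensors and invoke func_ here.
  }

 protected:
  tvm::runtime::PackedFunc func_;
};

class TVMDSOOp : public TVMPackedFuncOp {
 public:
  explicit TVMDSOOp(tensorflow::OpKernelConstruction* ctx)
      : TVMPackedFuncOp(ctx) {
    std::string lib_path, function_name;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("lib_path", &lib_path));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("function_name", &function_name));
    // A real implementation would cache modules in a process-wide singleton
    // so that each dynamic library is loaded only once.
    mod_ = tvm::runtime::Module::LoadFromFile(lib_path);
    func_ = mod_.GetFunction(function_name);
  }

 private:
  tvm::runtime::Module mod_;  // keep the module alive for the op's lifetime
};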

Additional notes on the API: we can still do some wrapping in option 1:

import tensorflow as tf
import tvm.contrib.tf_op as tfo

# Note just wraps things in python
mod = tfo.Module("tvm_add_one_dll.so")
addone = mod["addone"]

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  # constructs tvm_dso_op under the hood
  b = addone(a)
  print(sess.run(b))

In terms of implementation: the current API still has a malloc and a copy to move data into TVM. We should instead only construct a DLTensor object and use zero-copy to make direct use of the memory in TF for maximum efficiency.
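
For reference, here is a minimal sketch of what a zero-copy Compute body could look like, reusing the context, input_tensor and tvm_func names from the snippet above and assuming a 1-D float tensor on CPU; the DLTensor fields follow the DLPack header.

// Wrap the TF buffers in DLTensors instead of TVMArrayAlloc + memcpy.
auto input = input_tensor.flat<float>();
int64_t shape[1] = {static_cast<int64_t>(input.size())};

DLTensor x;
x.data = const_cast<float*>(input.data());  // reuse TF's input buffer directly
x.ctx = {kDLCPU, 0};
x.ndim = 1;
x.dtype = {kDLFloat, 32, 1};
x.shape = shape;
x.strides = nullptr;   // compact layout
x.byte_offset = 0;

Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                 &output_tensor));
DLTensor y = x;  // same shape, dtype and context as the input
y.data = output_tensor->flat<float>().data();  // write straight into TF's output

tvm_func(&x, &y);  // TVM reads and writes the TF buffers in place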

@tobegit3hub it would be great if you can summarize the current discussions and update the RFC proposal :slight_smile:

Thanks @tqchen and we will update the proposal and implementation soon.

It seems I cannot edit the original post now. I will update the design and implementation here, and everyone can try this feature with the latest code in https://github.com/tobegit3hub/tftvm .

@tqchen has suggested TVMPackedFuncOp and TVMDSOOp, which is great, and we have renamed the TensorFlow custom op class to TVMDSOOp. It is reasonable for TVMPackedFuncOp to take a PackedFunc as input, but it is not suitable as a TensorFlow operator by itself, because users cannot pass a PackedFunc object through TensorFlow operator attributes. So we do not have TVMPackedFuncOp for now, but it could become the parent class of TVMDSOOp if needed.

We have designed two major components for this proposal. The first is the TensorFlow custom op, now called TVMDSOOp. Currently it supports loading any TVM Module and computing on CPU and GPU. We have updated the registration to support more attributes so that users can specify the output shape/dtype themselves (it would be better if we could get this information from the TVM dynamic library files), and zero-copy still needs to be implemented (help here would be welcome). However, TensorFlow requires a custom op to register its number of input tensors, while TVMDSOOp is a general op for TVM and should support one or more inputs. We may therefore export multiple op variants like TVMDSOOp, and users do not need to know about this thanks to the Python wrapper API (see the registration sketch after the code below for a possible alternative). There is no magic in the implementation of TVMDSOOp and everybody can try it now.

class TVMDSOOp : public OpKernel {
 public:
  void Compute(OpKernelContext* context) override {
    // 1. Get the dynamic library path and function name from the op attributes
    // 2. Load the PackedFunc with the TVM runtime API
    // 3. Run the TVM function and return the result as a TensorFlow tensor
  }
};
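
On the multiple-input point above: instead of exporting one op variant per input count, TensorFlow op registration also supports a variable-length input list, which might look like this (a hedged sketch; the op and attribute names are assumptions):

// Sketch: a single registration that accepts N inputs of the same element type T.
REGISTER_OP("TvmDsoOp")
    .Attr("lib_path: string")
    .Attr("function_name: string")
    .Attr("N: int >= 1")                       // number of input tensors
    .Attr("T: {float, double, int32, int64}")  // element type shared by all inputs
    .Input("tvm_inputs: N * T")                // variadic input list
    .Output("tvm_output: T");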

Another major component of this proposal is the Python API. We agree to use option 2 and follow @tqchen’s suggestion to make the API more readable. Currently it looks like the TVM Python runtime, but the new Module and Func return a TensorFlow tensor instead of running the computation directly.

import tensorflow as tf
from tvm.contrib import tf_op

mod = tf_op.Module("tvm_addone_dll.so")
addone = mod["addone"]

with tf.Session() as sess:
  a = tf.constant([10.1, 20.0, 11.2, -30.3])
  b = addone(a)
  print(sess.run(b))

This Python package can be linked into tvm.contrib and everyone can try it after setting up the environment with the tftvm project. There is no magic in the implementation of the Python wrapper either: we use the TensorFlow Python library to load TVMDSOOp, feed the input tensors, and return the output tensor.

With TVMDSOOp in C++ and tf_op.Module in Python, developers can embed TVM operators in TensorFlow easily, with no need to develop a TensorFlow custom op by themselves. The future work is to make TVMDSOOp more general to support more devices and to stabilize the user-facing API.

Thanks for everyone’s effort and updates. I think the proposal is ready to be presented as a formal RFC and PR. @tobegit3hub can you send a formal RFC in a GitHub issue and a corresponding PR so that we can start the review process?

Thank you!

Thanks @tqchen . All the features have been implemented and the formal RFC and pull-request are in progress. We will create the pull-request from our official repository soon :slight_smile:

Hi @tobegit3hub

We have considered embedding TVM runtime into TF as a custom op previously.

However, due to resource limitations, we haven’t made much progress yet.

It is great to see that your work is almost ready to be submitted to the community.

And really look forward to your PR.

Thanks

Jun

The formal RFC is in https://github.com/apache/incubator-tvm/issues/4464 .

And the Pull-request is in https://github.com/apache/incubator-tvm/pull/4459 .

Are there any code examples for this topic?