Demo: a simple recursive function in Relay

This is my first shot at writing a Relay program. After some trial and error, I managed to write a simple recursive function in Relay that simulates a for loop. I'm posting the steps here to share them.

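  1. Compile TVM with ANTLR enabled (USE_ANTLR in config.cmake), since the Relay text parser behind relay.fromtext is generated with ANTLR.
  2. We will write a program that adds x to an accumulator acc a given number of times. It is equivalent to the following Python program: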
import numpy as np

def myfun(x, n):
  acc = np.zeros(shape=(2, 2))
  for i in range(n):
    acc += x  # add x to the accumulator once per iteration
  return acc

Since Relay has no loop construct and does not let us reassign a variable, we will use recursion with a stopping condition instead:

import numpy as np
from tvm import relay
import tvm
from tvm.contrib import graph_runtime

myrelay = """
v0.0.1

def @myfun(%acc : Tensor[(2, 2), float32],
           %x   : Tensor[(2, 2), float32],
           %n   : int32) -> Tensor[(2, 2), float32] {
  if (%n > 0) {
    @myfun(%acc + %x, %x, %n - 1)
  } else {
    %acc
  }
}
"""

my_relay_module = relay.fromtext(myrelay)
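Unrolling the recursive definition shows why it computes the same thing as the loop: each call folds one more copy of x into the accumulator and decrements n until the base case returns the accumulator. A short derivation, writing f for @myfun and a for %acc, and assuming n is a non-negative integer:

```latex
\begin{aligned}
f(a, x, n) &=
  \begin{cases}
    f(a + x,\ x,\ n - 1) & \text{if } n > 0,\\
    a & \text{otherwise,}
  \end{cases}\\[4pt]
f(a, x, n) &= f(a + x,\ x,\ n - 1) = f(a + 2x,\ x,\ n - 2) = \cdots = f(a + n\,x,\ x,\ 0) = a + n\,x.
\end{aligned}
```

With a = 0 and n = 10 this gives 10x, the value the assertion in the last step compares against. For n < 0 the base case fires immediately and f returns a, just as range(n) is empty in the Python loop.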

In my experience, it is easier to write the full Relay program as text and parse it with the relay.fromtext function. Running print(my_relay_module.astext()) prints the parsed program, with the inferred type of each intermediate value in the # ty= comments:

def @myfun(%acc: Tensor[(2, 2), float32],
           %x: Tensor[(2, 2), float32],
           %n: int32)
           -> Tensor[(2, 2), float32] {
  %0 = greater(%n, 0) # ty=bool
  if (%0) {
    %1 = add(%acc, %x) # ty=Tensor[(2, 2), float32]
    %2 = subtract(%n, 1) # ty=int32
    %3 = @myfun(%1, %x, %2) # ty=Tensor[(2, 2), float32]
    %3
  }  else {
    %acc
  }
}
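For readers less familiar with Relay's text format, the printed program can be read as the plain-Python recursion below. This is only an illustrative sketch: myfun_py is a hypothetical name, NumPy's np.greater, np.add and np.subtract stand in for the Relay operators of the same names, and each %k binding becomes a local variable.

```python
import numpy as np


def myfun_py(acc, x, n):
    """Plain-Python mirror of the printed @myfun (hypothetical helper)."""
    v0 = np.greater(n, 0)          # %0 = greater(%n, 0)
    if v0:
        v1 = np.add(acc, x)        # %1 = add(%acc, %x)
        v2 = np.subtract(n, 1)     # %2 = subtract(%n, 1)
        v3 = myfun_py(v1, x, v2)   # %3 = @myfun(%1, %x, %2)
        return v3
    else:
        return acc                 # base case: return the accumulator unchanged


x = np.ones((2, 2), dtype=np.float32)
acc = np.zeros((2, 2), dtype=np.float32)
print(myfun_py(acc, x, 3))
```

Unlike the loop, this Python version is bounded by CPython's recursion limit (1000 frames by default), so it is only meant for small n.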
  3. Compile the myfun function using relay.build_module.create_executor():
ctx = tvm.cpu()
opt_level = 3
target = tvm.target.create('llvm')
with relay.build_config(opt_level=opt_level):
  executor = relay.build_module.create_executor('debug', my_relay_module,
                                                ctx, target)

myfun_reified = executor.evaluate(my_relay_module.get_global_var('myfun'))
  4. Now run the compiled function:
data = np.random.uniform(1, 2, size=(2, 2)).astype('float32')
acc = np.zeros(shape=(2, 2), dtype=np.float32)

out = myfun_reified(acc, data, 10).asnumpy()
assert np.allclose(data * 10, out)
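The assertion above checks a single input with acc starting at zero. A broader check could compare the result against a NumPy reference on several random inputs, including a non-zero starting acc and n = 0. The sketch below is hypothetical: reference_myfun and check_myfun are names introduced here, and check_myfun accepts any callable taking (acc, x, n) and returning a NumPy array, so it runs without TVM.

```python
import numpy as np


def reference_myfun(acc, x, n):
    """NumPy reference: add x to acc n times (no additions when n <= 0)."""
    return acc + max(n, 0) * x


def check_myfun(fn, ns=(0, 1, 5, 10), trials=3, seed=0):
    """Compare fn(acc, x, n) with the reference on random float32 inputs.

    Returns True when every case matches within np.allclose tolerances.
    """
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        x = rng.uniform(1, 2, size=(2, 2)).astype("float32")
        acc = rng.uniform(-1, 1, size=(2, 2)).astype("float32")
        for n in ns:
            if not np.allclose(fn(acc, x, n), reference_myfun(acc, x, n)):
                return False
    return True
```

With the executor from the previous step, one would pass a small wrapper such as check_myfun(lambda a, x, n: myfun_reified(a, x, n).asnumpy()).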

Actually, sometimes we may find the Relay text format easier to write :slight_smile: