[RFC] Handling Effect in TVM and Relay

I don’t think we need to complicate the issue here, in particular with respect to thread safety and atomic updates (we don’t have to introduce atomic updates).

Note that the final goal (a program that explicitly passes the state around) won’t have this problem. It is the program that users write, whose semantics rely on a PRNG state (which we can define as global to the module or thread-local), that has the problem.

The main question is how easy it is to convert the program that uses the PRNG state to the explicit passing style. After that conversion, the execution engine can then decide how to implement the PRNG.

I agree on atomic update 🙂

But the essential problem here is that we are adding a virtual dependency between operators that have the same effect. For example, a Relay function might contain two dropouts. If we force them to share the same global PRNG, those two dropouts will be serialized.

We could insert ref reads/writes around the operators that have effects; this ensures correctness. Then, as Marisa mentioned, we could rely on partial evaluation to optimize most of them out.

What is happening here is that we represented a concept (nondeterminism) by its implementation (a PRNG). The principled solution is to add nondeterminism to Relay; then we can optimize at the Relay level, knowing that two nondeterministic functions can be swapped.


Per an offline discussion with @tqchen, let’s limit the scope to the usage of PRNG states.

Solution A1. Allocating single global reference. As suggested by @MarisaKirisame, we can maintain a global state inside a module, which indicates the PRNG key. Every time we encounter a stateful operator f, we insert a RefRead before f, whose result is passed to f as an additional argument, and then put a RefWrite after f.

Example.

fn my_func(x) {
    let y = Dropout(x);
    y
}

will be transformed to

fn my_func(x) {
    let prng_key = RefRead(global_prng_state);
    let (y, new_prng_key) = Dropout(x, prng_key);
    RefWrite(global_prng_state, new_prng_key);
    y
}

This transformation should be applied at the earliest stage, so that the program is correct and all our existing passes keep functioning properly.
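To make the A1 semantics concrete, here is a minimal Python model. It assumes a toy `dropout(x, key)` that returns `(y, new_key)` and draws its randomness from Python's `random.Random` seeded with the key; `Ref`, `ref_read` and `ref_write` are hypothetical stand-ins for Relay's RefRead/RefWrite, not Relay's API.

```python
import random


def dropout(x, key, rate=0.5):
    """Toy pure dropout: the same (x, key) always gives the same (y, new_key)."""
    rng = random.Random(key)
    scale = 1.0 / (1.0 - rate)
    y = [0.0 if rng.random() < rate else v * scale for v in x]
    new_key = rng.getrandbits(32)  # derive the next key from the same stream
    return y, new_key


class Ref:
    """Mutable cell standing in for a Relay reference (hypothetical)."""

    def __init__(self, value):
        self.value = value


def ref_read(ref):
    return ref.value


def ref_write(ref, value):
    ref.value = value


# Module-level PRNG state, as in Solution A1.
global_prng_state = Ref(42)


def my_func(x):
    # RefRead before the stateful op, RefWrite after it.
    prng_key = ref_read(global_prng_state)
    y, new_prng_key = dropout(x, prng_key)
    ref_write(global_prng_state, new_prng_key)
    return y
```

Each call advances the global state, so resetting the reference to the same initial key reproduces the same outputs.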

Solution A2. Passing states everywhere. We can also add an extra argument called prng_state to all functions (including closures), and an extra item to their return values. The newly generated PRNG state shadows the old one instead of using mutation. Note that in this case we should handle if expressions properly: both branches must return the current state, even when only one of them uses it.

Example.

fn my_func(x) {
    let y = Dropout(x);
    y
}

will be transformed to

fn my_func(x, prng_state) {
    let (y, prng_state) = Dropout(x, prng_state);
    (y, prng_state)
}

The extra argument could be eliminated if and only if all callees in the call-graph do not use it.
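The same toy model in A2 style, again assuming a toy `dropout` built on `random.Random` (redefined here so the sketch stands alone). It also shows the if-expression case: the branch that does not use randomness still forwards the state. `maybe_dropout` is a hypothetical example function.

```python
import random


def dropout(x, key, rate=0.5):
    """Toy pure dropout: returns (y, new_key) without touching any global."""
    rng = random.Random(key)
    scale = 1.0 / (1.0 - rate)
    y = [0.0 if rng.random() < rate else v * scale for v in x]
    return y, rng.getrandbits(32)


def my_func(x, prng_state):
    # The new state shadows the old binding instead of mutating a global.
    y, prng_state = dropout(x, prng_state)
    return y, prng_state


def maybe_dropout(train, x, prng_state):
    # Both branches of the "if" return a (value, state) pair, so the
    # branch without randomness forwards the incoming state unchanged.
    if train:
        return my_func(x, prng_state)
    return x, prng_state
```

Because nothing is mutated, calling `my_func` twice with the same state gives the same result, which is what makes the surrounding code easier to reorder.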

From Solution A1 to A2. There is indeed a connection between A1 and A2. In many cases, using partial evaluation will reduce Solution A1 to A2.

A2 is essentially a monad. While it might be a good idea here, I do not think it is a good solution to the effect problem.

A monad is a programming trick in Haskell and other purely functional programming languages, which must stay pure and respect referential transparency (calling the same function twice gives the same result). A monad encapsulates effects by, in a sense, creating a sub-language that can do whatever the metalanguage (Haskell) can do, plus an API for certain impure effects.

For example, here is the monad that corresponds to arbitrary state:

data State a b              -- has a global variable of type a, and eventually returns b
  = ReturnState b           -- a value of type b in Haskell
  | Read (a -> State a b)   -- read the global variable, then continue executing with that value
  | Write a (State a b)     -- write to the global variable, then continue executing
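To illustrate that the inner language still has effects, here is a rough Python transliteration of the State structure above plus an interpreter that runs it against a single variable. The class names mirror the constructors; `run` and `incr` are our own illustrative names.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ReturnState:
    value: Any  # final result (type b)


@dataclass
class Read:
    cont: Callable[[Any], Any]  # receives the current global value


@dataclass
class Write:
    new_value: Any  # value to store in the global variable
    rest: Any       # program to continue with


def run(program, state):
    """Interpret a State program; returns (result, final_state)."""
    while True:
        if isinstance(program, ReturnState):
            return program.value, state
        if isinstance(program, Read):
            program = program.cont(state)
        elif isinstance(program, Write):
            state, program = program.new_value, program.rest
        else:
            raise TypeError(f"not a State program: {program!r}")


# Increment the global counter and return its old value.
incr = Read(lambda old: Write(old + 1, ReturnState(old)))
```

Note that `incr` itself is an ordinary immutable value; the effect only happens inside `run`, and the continuation in `Read` is an opaque function that a pass cannot inspect.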

However, while the outer language is pure, the inner language (State, in this example) still has effects, and many passes need to do effect-specific optimizations (e.g. if two writes do not alias, you can swap them). Operating on the pure fragment merely lets us build up this monad purely, but generally cannot peek into its structure. Using monads also brings tons of programming problems (what happens if you have two monads and want to combine them? This is a very hard problem in Haskell, and people are still designing new solutions to this day).

The more principled solution, IMO, is to add effects directly into Relay and add optimizations at that level. If the above A2 solution is desired, we can add a “local variable” effect, which should be easier to optimize.


@jroesch, got some time to hop in?

Hey Marisa,

You are right, it is essentially a (very simplified) monad; this is why I limit the discussion scope to PRNGs only, and why I say “in many cases” A1 can be reduced to A2, but not in all cases. I also agree with you that bringing in full monads brings us tons of problems, and it is impractical for now.

To make it practical, let’s simplify the problem by limiting the scope. Let’s step back and look at what the problem looks like: contributors of frontend importers want to use 1) PRNGs; 2) batch normalization (cuDNN’s implementation contains in-place mutation, and unfortunately many libraries just follow this design).

For PRNGs, we just need to ensure reproducibility (or, to be exact, referential transparency).

For batch normalization, from my point of view, it is more correct to make it pure, like: BatchNorm(x, running_mean, running_var) -> (y, new_running_mean, new_running_var). But we need to make sure frontend importers work this way.
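A minimal sketch of that pure signature in plain Python, over a batch given as a list of feature rows. The momentum-based running-average update, the biased variance and the `eps` default are common conventions assumed here, not taken from any specific library.

```python
import math


def batch_norm(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    """Pure training-mode batch norm.

    x is a list of rows (batch x features). Returns
    (y, new_running_mean, new_running_var) instead of updating the
    running statistics in place; the inputs are left untouched.
    """
    n = len(x)
    num_features = len(x[0])
    mean = [sum(row[j] for row in x) / n for j in range(num_features)]
    var = [sum((row[j] - mean[j]) ** 2 for row in x) / n
           for j in range(num_features)]
    y = [[(row[j] - mean[j]) / math.sqrt(var[j] + eps)
          for j in range(num_features)] for row in x]
    new_mean = [(1 - momentum) * rm + momentum * m
                for rm, m in zip(running_mean, mean)]
    new_var = [(1 - momentum) * rv + momentum * v
               for rv, v in zip(running_var, var)]
    return y, new_mean, new_var
```

An importer for a framework that mutates the running statistics would then bind the two extra outputs to the new parameter values instead of writing through a reference.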

How does TVM deal with in-place mutation in general, by the way? Does TVM plan to support in-place updates? There should be some optimization that turns updates of a Ref of Tensor into in-place updates of the underlying value (when possible). I think MXNet does this, and we should too.

MXNet supports in-place mutation. To deal with races, it has an almost general-purpose dependency engine, which builds a dependency graph inside… https://mxnet.apache.org/api/architecture/note_engine

Immediate idea: if this is costly, we can use static analysis to find the race-free cases. But I assume nothing’s cost comes close to that of calling a GPU kernel… sigh.
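A minimal sketch of the kind of static analysis meant here, assuming each op declares the sets of variables it reads and writes (in the spirit of MXNet's engine, not its actual implementation). Two ops must be ordered only if one writes something the other reads or writes, which also shows why two dropouts sharing one PRNG state get serialized.

```python
def must_order(a, b):
    """True if ops a and b conflict, i.e. cannot run concurrently.

    Each op is a dict with 'reads' and 'writes' sets of variable names.
    """
    return bool(
        a["writes"] & (b["reads"] | b["writes"])
        or b["writes"] & a["reads"]
    )


def dependency_edges(ops):
    """Edges (i, j), i < j, meaning op j must wait for op i (program order)."""
    return [
        (i, j)
        for i in range(len(ops))
        for j in range(i + 1, len(ops))
        if must_order(ops[i], ops[j])
    ]


ops = [
    {"name": "dropout1", "reads": {"x", "prng"}, "writes": {"y1", "prng"}},
    {"name": "dropout2", "reads": {"x", "prng"}, "writes": {"y2", "prng"}},
    {"name": "relu", "reads": {"x"}, "writes": {"y3"}},
]
```

Here only the two dropouts are ordered, because of the shared `prng` variable. If each dropout instead receives its own key (as in A2), the edge disappears and they can run in parallel.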

Also, for the above code, we can write a pass that is guaranteed to turn A1 into A2 by essentially injecting a ‘monad’ into the program. This is still transparent to the end user, though.


One thing worth noting is that we are not trying to optimize general-purpose programs, but deep learning programs that happen to have (a few) state updates due to PRNGs, parameter updates and IO.

We should bear in mind that most developers do not know how to deal with state, and most optimizations do not need to deal with it. So ideally we want to isolate the pure parts into basic blocks, as in Basic Block Normal Form, and not touch state updates in most cases.

While it is possible to do alias analysis and move mutations around, we have to admit that the cost of writing compiler passes for this style of program (as in A1) is higher. The explicit passing style (monad, A2), on the other hand, yields larger regions of “pure programs”, which means basic block optimizers could bring more benefits (e.g. by being able to reorder other ops with random ones). The explicit passing style is also friendlier to transformation into an execution graph that runs on multiple machines (devices), in which case we would need to pass the state around anyway.

On the other hand, from the frontend writer’s perspective, it could be tricky to write all state updates in the monad style, and the program in A1 is much easier to write because the state does not have to be passed around.

So to summarize some key points:

  • It is relatively easy to write optimizations for pure basic blocks, and we want to enable that (by producing bigger basic blocks when possible).
  • The state mutation program breaks basic blocks apart (because we want to make sure basic blocks remain pure), so the program in A1 is less friendly for optimization (of course alias analysis would help, but that adds complexity to pass development). Explicit state passing style in a bigger block is more desirable for optimization (of course we could still use state reads/writes as the source/sink of the state passing).
  • On the other hand, from the frontend perspective, most programs are written in the A1 style.

Considering the above points, I think it is important to:

  • Design an interface that explicitly passes state around (while allowing refs to record state updates)
  • Design rewriting passes to rewrite A1-style programs into A2 style
    • The rough idea is to collect all the affected functions and their states, then add the states as function arguments and return values.
    • Note that the rewriting can be done partially, which only removes some of the state mutations but still keeps state reads/writes as sources/sinks at an upper level.
  • Encourage users to write basic block optimizations that deal with functions that explicitly pass the state

For Relay/TVM, are we mainly concerned with compute operators (random number generators), or do we want to support IO input/output as well?

I don’t think there is much difference in this case.

I took a look at Jax’s approach and want to summarize and share it below:

  1. A random number function requires an explicit random state (a “key”) as one of its arguments. Calling the random number function does not change the key/random state:

nums = random.normal(key, shape=(1,))

  2. To get a different random state, users need to call the split function explicitly:
key, subkey = random.split(key)
different_nums = random.normal(subkey, shape=(1,))
  3. In a training program, the keys are split and replaced outside the JIT program, in Python. The JIT program just takes all new keys as inputs, and they are immutable.

Since random.normal does not change the state, most off-the-shelf random number generators cannot be plugged in. Jax designed its own PRNG module and implementation.
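The key discipline can be sketched in plain Python without Jax. The `split` and `normal` functions below are a hypothetical toy that only mimics the semantics described above (keys are immutable values, sampling does not change a key, and `split` derives fresh keys by hashing the parent with SHA-256); it is not Jax’s API or its PRNG.

```python
import hashlib
import random


def split(key, num=2):
    """Derive `num` new keys from `key` deterministically (toy, SHA-256 based)."""
    return tuple(
        int.from_bytes(hashlib.sha256(f"{key}/{i}".encode()).digest()[:8], "little")
        for i in range(num)
    )


def normal(key, shape):
    """Draw standard normals determined entirely by `key`; the key is unchanged."""
    rng = random.Random(key)
    count = 1
    for dim in shape:
        count *= dim
    return [rng.gauss(0.0, 1.0) for _ in range(count)]


key = 0
nums = normal(key, shape=(1,))
key, subkey = split(key)  # replace the old key, use the subkey once
different_nums = normal(subkey, shape=(1,))
```

Because `normal` is a pure function of the key, replaying the same key reproduces the same numbers, which is exactly why a stateful off-the-shelf generator cannot be dropped in unchanged.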


Just to follow up on what @tqchen summarized previously, here’s my understanding:

frontend converters

We want users who write frontend converters to be aware that certain operators are stateful. We can encourage them to write these operations in A1 style. For instance:

def _mx_dropout_train(inputs, attrs, module):
    rate = attrs.get_float("p", 0.5)
    global_state = module['prng']
    state_ref = relay.RefCreate(global_state)
    read_state = relay.RefRead(state_ref)
    # the dropout_train operator outputs both y and the new state 
    y_state = _op.nn.dropout_train(inputs[0], read_state, rate=rate)
    # write back new state, return y 
    write_state = relay.RefWrite(state_ref, y_state[1])
    y = relay.Let(relay.var('ref_write'), write_state, y_state[0])
    return y

where module['prng'] is a global variable representing the PRNG state in the module. Currently, global variables are only used to represent functions, so we need to extend them to represent the random state, too.

rewriting A1-style programs to A2-style ones

Let’s say we have a function below with stateful ops:

def @func1(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  let %ref_write: () = (%0 := %3);
  %2.0
}

In the rewriting pass, we detect that the global random state is used, and replace its references as follows:

def @func1_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  (%2.0, %2.1)
}

Note that the function output type is changed to a tuple containing the new state. Meanwhile we need to update all CallNodes for this function accordingly. Here is another example:

def @long_func(%x) {
  %0 = ref(@prng_state);
  %1 = %0^;
  %2 = nn.dropout_train(%x, %1, rate=0.7f);
  %3 = %2.1;
  %4 = (
    let %ref_write1: () = (%0 := %3);
    %2.0
  );
  %5 = %0^;
  %6 = nn.dropout_train(%4, %5, rate=0.1f);
  %7 = %6.1;
  let %ref_write: () = (%0 := %7);
  %6.0
}

===> 

def @long_func_rewritten(%x, %state) {
  %2 = nn.dropout_train(%x, %state, rate=0.7f);
  %3 = %2.1;
  %4 = %2.0;
  %6 = nn.dropout_train(%4, %3, rate=0.1f);
  (%6.0, %6.1)
}

Note that the pass implementation requires tracking the latest value of the global variable within each scope. For instance, the program below:

def @func2(%x, %y) { # returns tensor
  if (%x) {
    add(%x, %y)
  } else {
    @func1(%y)
  }
}

would be rewritten to:

def @func2(%x, %y, %state) {  # returns (tensor, state) for both branches
  if (%x) {
    (add(%x, %y), %state) # the original state is also returned
  } else {
    @func1_rewritten(%y, %state) # returns the new state
  }
}

Since the pass needs to track the state within each scope, it would be easier to implement it after the program has already been transformed into basic block normal form.
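A toy version of this rewriting over a tiny tuple-based IR (not Relay’s ExprMutator; all statement shapes and names are hypothetical) shows the bookkeeping: each stateful call receives the current state as an extra argument and binds a fresh state variable, and an `if` binds the latest state of both branches to one name, so a branch without stateful calls simply forwards the incoming state, as in `@func2` above. Calls to already-rewritten functions such as `@func1` are treated as stateful calls.

```python
import itertools


def fresh_names(prefix="%s"):
    """Generate fresh state variable names: %s1, %s2, ..."""
    return (f"{prefix}{i}" for i in itertools.count(1))


def rewrite(stmts, state, fresh):
    """Thread an explicit state variable through A1-style statements.

    Statement shapes (hypothetical):
      ("call", out, op, args)          pure call
      ("stateful", out, op, args)      reads and writes the global PRNG state
      ("if", cond, then_stmts, else_stmts)
    Returns (new_stmts, name_of_latest_state).
    """
    result = []
    for stmt in stmts:
        kind = stmt[0]
        if kind == "call":
            result.append(stmt)
        elif kind == "stateful":
            _, out, op, args = stmt
            new_state = next(fresh)
            # (out, new_state) = op(*args, state)
            result.append(("call_with_state", (out, new_state), op, args + (state,)))
            state = new_state
        elif kind == "if":
            _, cond, then_stmts, else_stmts = stmt
            then_new, then_state = rewrite(then_stmts, state, fresh)
            else_new, else_state = rewrite(else_stmts, state, fresh)
            merged = next(fresh)
            # Each branch binds its latest state to the same name; a branch
            # without stateful calls just forwards the incoming state.
            result.append(("if", cond,
                           then_new + [("assign", merged, then_state)],
                           else_new + [("assign", merged, else_state)]))
            state = merged
        else:
            raise ValueError(f"unknown statement: {stmt!r}")
    return result, state
```

The function-level part of the pass would then add the initial state as a parameter and return `(result, final_state)`, as in `@func1_rewritten`.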

discussions

what type do we use for the random state?

  • option a: use the empty tuple type. The runtime actually uses the global state, and it relies on the deterministic execution order of the program to ensure reproducibility.
  • option b: add a new type (e.g. TypeRandState), and the random state Object actually carries the data structure used for generating random numbers (e.g. std::mt19937). The state is passed around in the program, and invoking an operator with the same state object always leads to the same deterministic outputs.
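Option b can be sketched with Python’s `random.Random`, whose core generator is MT19937. `RandState` is a hypothetical immutable wrapper around a snapshot of the generator state, and `uniform` a toy op in state-passing form; invoking it with the same state object always yields the same output and a new state.

```python
import random


class RandState:
    """Immutable snapshot of an MT19937 generator state (hypothetical type)."""

    def __init__(self, snapshot):
        self._snapshot = snapshot

    @classmethod
    def from_seed(cls, seed):
        return cls(random.Random(seed).getstate())

    def generator(self):
        # A fresh generator positioned at this snapshot; never shared.
        rng = random.Random()
        rng.setstate(self._snapshot)
        return rng


def uniform(state, n):
    """Toy stateful op in state-passing form: returns (values, new_state)."""
    rng = state.generator()
    values = [rng.random() for _ in range(n)]
    return values, RandState(rng.getstate())


s0 = RandState.from_seed(0)
a, s1 = uniform(s0, 3)
```

The cost of this design is that every op copies a full generator state; option a avoids the copy but then relies on a deterministic execution order for reproducibility.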

@junrushao @haichen @MarisaKirisame @ziheng would you like to provide some suggestions/comments?


Hey @eric-haibin-lin, thank you for the valuable examples!

I was wondering whether we can further reduce frontend developers’ burden in writing operators with side effects. For example, frontend developers would only be required to produce a program like the one below:

fn @just_a_dropout(%x) {
  let %y = _stateful_dropout(%x);
  %y
}

and then we provide a pass to replace _stateful_dropout properly:

fn @just_a_dropout(%x) {
  let %prng = RefRead(%global_var_prng_ref);
  let (%y, %new_prng) = dropout(%x, %prng);
  RefWrite(%global_var_prng_ref, %new_prng);
  %y
}

Note that this approach requires that we ignore potential dependency issues. For example, if there are two parallel dropouts with no dependency between each other, this approach would add an arbitrary dependency between them. However, in our case (neural networks), it doesn’t really matter.
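A rough sketch of such a lowering pass over a hypothetical list-of-bindings IR (not Relay’s API): every binding whose op appears in a table of stateful ops is expanded into a read of the global PRNG reference, a call to the pure op with the key as an extra argument, and a write-back. `STATEFUL_OPS` and `lower_stateful` are illustrative names.

```python
# Hypothetical mapping from "stateful" frontend ops to their pure versions.
STATEFUL_OPS = {"_stateful_dropout": "dropout"}


def lower_stateful(bindings, prng_ref="%global_var_prng_ref"):
    """Expand each (var, op, args) binding that uses a stateful op.

    args is a tuple of argument names; returns the lowered program as lines.
    """
    out = []
    counter = 0
    for var, op, args in bindings:
        if op not in STATEFUL_OPS:
            out.append(f"let {var} = {op}({', '.join(args)});")
            continue
        counter += 1
        key, new_key = f"%prng{counter}", f"%new_prng{counter}"
        pure_args = ", ".join(args + (key,))
        out.append(f"let {key} = RefRead({prng_ref});")
        out.append(f"let ({var}, {new_key}) = {STATEFUL_OPS[op]}({pure_args});")
        out.append(f"RefWrite({prng_ref}, {new_key});")
    return out
```

Lowering two independent `_stateful_dropout` bindings routes both through the same reference, which is the arbitrary ordering mentioned above.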


I don’t have a problem with rewriting to state-passing style, and I agree about the potential typing issues with the PRNG key. As @eric-haibin-lin mentioned, it can be either a unit type (empty tuple) or a new type to be introduced to Relay. Would love to hear more from @tqchen, @MarisaKirisame and @jroesch.

Hey everyone, reviving this thread as @tkonolige, @jroesch, @antinucleon and I have been experimenting with adding some PRNG support to Relay.

While nothing is finalized, we are currently trying the Jax approach of explicit user-side PRNG key management. The reasoning is as follows: in general, most networks use PRNG quite simply, so the user defining the network

  1. in Relay, can easily pass around a key while they build the network
  2. in importers, can have a mutable field storing the key (in Python) and replace the key by splitting after each use.

Here is an example taken from Relay (VGG) and modified as suggested:

def get_classifier(input_data, num_classes, prng_key):
    """Get VGG classifier layers as fc layers."""
    left, right = relay.random.split(prng_key)
    flatten = relay.nn.batch_flatten(data=input_data)
    fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
    relu6 = relay.nn.relu(data=fc6)
    drop6 = relay.nn.dropout(data=relu6, rate=0.5, key=left)
    fc7 = wrapper.dense_add_bias(data=drop6, units=4096, name="fc7")
    relu7 = relay.nn.relu(data=fc7)
    drop7 = relay.nn.dropout(data=relu7, rate=0.5, key=right)
    fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
    return fc8

Note that we must only use a PRNG key once (treat it as a linear resource), but here we don’t need a key again, so we can use both results from random.split. Then, to use this in defining the full network, we can simply write (for example):

classifier = get_classifier(feature, num_classes, relay.random.key(0))
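Since keys must be treated as linear resources, a frontend could cheaply check that no key is consumed twice. A hypothetical sketch, assuming keys are hashable values; `KeyTracker` and `KeyReuseError` are illustrative names, not an existing Relay API.

```python
class KeyReuseError(RuntimeError):
    """Raised when a PRNG key is used more than once."""


class KeyTracker:
    """Records which PRNG keys were consumed and rejects a second use."""

    def __init__(self):
        self._used = set()

    def consume(self, key):
        if key in self._used:
            raise KeyReuseError(f"PRNG key {key!r} used more than once")
        self._used.add(key)
        return key
```

Note that splitting also consumes the parent key, so `split` itself would call `consume` on its input.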

Here is an example with an importer (Caffe):

    def convert_dropout(self, op):
        """ Convert Dropout layer """
        inputs = op.bottom
        input_name = inputs[0]
        next_key, dropout_key = _op.random.split(self.prng_key)
        self.prng_key = next_key  # we can wrap this in some helper method

        params = dict()
        dropout_params = op.dropout_param

        params["rate"] = dropout_params.dropout_ratio

        in_expr = self.exp_tab.get_expr(input_name)
        out = _op.nn.dropout(in_expr, **params, key=dropout_key)
        return out
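The key bookkeeping in `convert_dropout` could be factored into a helper, as the comment suggests. A minimal sketch, assuming the split function is injected so it can be the proposed `_op.random.split` in the importer or a plain-Python stand-in for testing; `PRNGKeyHolder` and `next_key` are hypothetical names.

```python
class PRNGKeyHolder:
    """Hypothetical helper for importers that keep a mutable PRNG key."""

    def __init__(self, prng_key, split):
        self.prng_key = prng_key
        self._split = split  # returns (next_key, use_key), like random.split

    def next_key(self):
        """Split the current key, keep one half, and hand out the other."""
        self.prng_key, use_key = self._split(self.prng_key)
        return use_key
```

With this, the two key lines in `convert_dropout` would reduce to `dropout_key = self.next_key()`.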

Would love to hear your thoughts and if I missed any kind of edge case here! I’m also happy to try and write some pseudocode examples for more complicated use cases if anyone is interested.

EDIT: I forgot to mention that, as @eric-haibin-lin mentioned, this means we cannot plug in off-the-shelf PRNG kernels. However, we have written a splittable PRNG kernel in TIR (Threefry), so some of the groundwork is done. This will also let us run on GPU.
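For reference, the Threefry-2x32 block function (20 rounds) that such a splittable PRNG builds on fits in a few lines of Python. This is a plain-Python sketch following the published Threefry construction (as also used in Jax’s reference implementation), not the TIR kernel mentioned above.

```python
MASK32 = 0xFFFFFFFF
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))


def _rotl32(v, r):
    return ((v << r) | (v >> (32 - r))) & MASK32


def threefry2x32(key, counter):
    """Threefry-2x32, 20 rounds: maps a 64-bit key and counter to 64 random bits.

    key and counter are pairs of 32-bit unsigned integers.
    """
    ks = (key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDA)
    x0 = (counter[0] + ks[0]) & MASK32
    x1 = (counter[1] + ks[1]) & MASK32
    for i in range(5):  # 5 groups of 4 rounds = 20 rounds
        for r in ROTATIONS[i % 2]:
            x0 = (x0 + x1) & MASK32
            x1 = _rotl32(x1, r) ^ x0
        # Key injection after every group of 4 rounds.
        x0 = (x0 + ks[(i + 1) % 3]) & MASK32
        x1 = (x1 + ks[(i + 2) % 3] + i + 1) & MASK32
    return x0, x1
```

Since the output is a pure function of (key, counter), new keys can be derived by evaluating it on distinct counters, e.g. `threefry2x32(key, (0, i))` for the i-th child; that particular derivation scheme is an illustrative assumption.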
