
Initialization and Variables #237

rnett commented Mar 11, 2021

As I'm working on a tf.function-style API, I keep running into issues with initialization and variables, since they are all very interconnected. This is a proposal to streamline initialization and variable handling with an eye towards functions. I'm not going to get too far into the function API itself here, other than assuming it has builders like ConcreteFunction's and builds a ConcreteFunction per input signature (like Python's). I'm also only working with the TF2-style resource variables (which are currently waiting on gradient support). The current draft is #179; I would change it to reflect what is described here.

I'd like to add the new variables at some point (once their design here is finalized) even w/o gradient support so that I can work on functions, which rely on them. Ideally this would be after 0.3.0 though, and either way we can mark them as "do not use" for now.

Python

Initialization

Note: auto-capture of eager variables by functions is allowed (i.e. passing them to ops).

Handled by init_scope:

  • Graph: similar to what we do now, a special scope whose ops are collected and can be run later
  • Eager session: the ops are just run immediately
  • Function: an eager session owned by the function graph. Operands from that session can be used in the function w/o anything extra.
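
For concreteness, here's a rough Kotlin sketch of how those three behaviors could hang off one entry point. To be clear, initScope(), withInitScope(), FunctionGraph, and eagerInitSession() are all hypothetical names for the behaviors above, not existing API:

import org.tensorflow.EagerSession
import org.tensorflow.Graph
import org.tensorflow.op.Ops

// Rough sketch only: withInitScope(), FunctionGraph, and eagerInitSession() are
// hypothetical names for the behaviors described above.
fun Ops.initScope(): Ops = when (val env = scope().env()) {
    is Graph -> withInitScope()                  // graph: collect ops to run later
    is EagerSession -> this                      // eager: ops run as soon as they're created
    is FunctionGraph -> Ops.create(env.eagerInitSession())  // function: eager session owned
                                                            // by the function graph
    else -> this
}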

Variables

Handle creation and initialization are done in init_scope. Functions also have a fallback initialization using cond/if and isInitialized when the initial value can't be lifted to the init scope (i.e. when it depends on inputs); see the sketch below. This has XLA compatibility issues.

This results in all function variables being global, and in the limitation that variables can't be created on calls after the first.
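
To make that fallback concrete, here's roughly what it does, sketched with tf-java's existing resource-variable ops. Only tf.cond is made up here; we'd have to build that helper on the raw If op:

import org.tensorflow.Operand
import org.tensorflow.ndarray.Shape
import org.tensorflow.op.Ops
import org.tensorflow.types.TFloat32

// Sketch of the cond/isInitialized fallback for an initial value that depends on inputs.
// varHandleOp, varIsInitializedOp, assignVariableOp, and readVariableOp exist in tf-java
// today; tf.cond is hypothetical (we'd have to build it on the raw If op).
fun fallbackInit(tf: Ops, initialValue: Operand<TFloat32>): Operand<TFloat32> {
    val handle = tf.varHandleOp(TFloat32::class.java, Shape.scalar())
    val needsInit = tf.math.logicalNot(tf.varIsInitializedOp(handle))
    tf.cond(needsInit) { tf.assignVariableOp(handle, initialValue) }  // hypothetical helper
    return tf.readVariableOp(handle, TFloat32::class.java)
}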

Functions

Python uses a graph subclass for special handling, and we will need one as well. It also inserts control dependencies between stateful ops, which we can and should do too.
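
Roughly, the mechanism is just "remember the last stateful op and add it as a control input to the next one". A hypothetical sketch of how we could do the same (withControlDependencies is real tf-java API; the hook around op creation is not):

import org.tensorflow.op.Op
import org.tensorflow.op.Ops

// Hypothetical sketch: the function graph remembers the last stateful op and threads it
// into the next one as a control dependency, so stateful ops execute in creation order.
class StatefulOpChain {
    private var lastStateful: Op? = null

    fun <T : Op> add(tf: Ops, stateful: Boolean, build: (Ops) -> T): T {
        val scoped = lastStateful?.let { tf.withControlDependencies(listOf(it)) } ?: tf
        val op = build(scoped)
        if (stateful) lastStateful = op
        return op
    }
}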

My Proposal

Initialization

First, we should do initialization like Python does: have a special initScope() that tracks ops for graphs, runs them immediately for eager sessions, and is an eager session for functions.

Something like:

fun makeVariable(tf: Ops) {
    val variable = tf.initScope().Variable()
}

However, this has a rather large issue: functions will be re-executed multiple times (when the input signature changes), re-creating this variable multiple times. The easy answer is to do something like:

private var cache: Variable? = null
fun makeVariable(tf: Ops) {
    if (cache == null)
        cache = tf.initScope().Variable()
    val variable = cache!!
}

However, this is clunky, and it requires users to know whether they are in a function context (since this kind of caching isn't necessary most of the time in other contexts).

Something better and more environment-agnostic is to expose an isInit() boolean that is always true for non-functions and only true on a function's first run. Then this becomes:

private var cache: Variable? = null
fun makeVariable(tf: Ops) {
    if (tf.isInit())
        cache = tf.initScope().Variable()
    val variable = cache!!
}

or even

private var cache: Variable? = null
fun makeVariable(tf: Ops) {
    tf.init { initScope -> cache = initScope.Variable() }
    val variable = cache!!
}

(this type of assignment from the lambda won't work in Java, which is why we need the boolean for the if statement).

With a compiler plugin, in Kotlin we can use Compose-style call-site memoization, like:

fun makeVariable(tf: Ops) {
    val variable = remember { initScope -> initScope.Variable() }
}

We could have a manually keyed version for Java as well, although that could be a bit prone to key collisions.
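
Something like the following, where the user-supplied key stands in for the compiler-generated call-site key (remember and initScope are from this proposal, not existing API):

import org.tensorflow.op.Ops

// Hypothetical keyed variant. Two call sites that pass the same key silently share one
// value, which is the collision hazard. A real version would scope the cache to the
// execution environment rather than using a global map.
private val rememberCache = mutableMapOf<String, Any>()

@Suppress("UNCHECKED_CAST")
fun <T : Any> Ops.remember(key: String, create: (Ops) -> T): T =
    rememberCache.getOrPut(key) { create(initScope()) } as T

// Usage: val variable = tf.remember("sum") { init -> init.Variable() }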

Variables

Variable handling in Python is rather inconsistent, and I want to do better here. Consider a basic, if odd, addition function:

import tensorflow as tf


def add(a, b):
  sum = tf.Variable(0)
  sum += a
  sum += b
  return sum


c = add(1, 2)

In eager mode, this works as expected over multiple calls, returning the sum each time.

In graph mode, it's a bit weird. Calling the function multiple times works fine. However, each evaluation of the function's output re-uses the same variable, meaning the first time c is evaluated it is 3, the second time it is 6. Worse, getting add to work as expected requires manually adding control dependencies for the assigns (see the sketch below). Notably, this isn't necessary in functions, because control dependencies are automatically added between stateful ops in the order they are created. I think it's worth considering doing the same for our Graphs (with an opt-out).
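
For reference, the manual ordering in tf-java today looks something like this, using the existing (non-resource) variable ops:

import org.tensorflow.Graph
import org.tensorflow.ndarray.Shape
import org.tensorflow.op.Op
import org.tensorflow.op.Ops
import org.tensorflow.types.TInt32

// Everything here is existing tf-java API. Without the withControlDependencies calls,
// nothing orders the two assignAdds relative to each other or to the read of the result.
val graph = Graph()
val tf = Ops.create(graph)
val sum = tf.variable(Shape.scalar(), TInt32::class.java)
val init = tf.assign(sum, tf.constant(0))
val addA = tf.withControlDependencies(listOf<Op>(init)).assignAdd(sum, tf.constant(1))
val result = tf.withControlDependencies(listOf<Op>(addA)).assignAdd(sum, tf.constant(2))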

In functions, it works like eager mode the first time and then errors.

I'd propose separating local and global variables. Local variables are not created in initScope, and they add control dependencies between init, assignments, and reads (and are not trainable). So

fun add(tf: Ops, a: Operand<TFloat32>, b: Operand<TFloat32>): Operand<TFloat32> {
    val sum = tf.LocalVariable(0f)
    sum += a
    sum += b
    return sum.read()
}


val c = add(tf, tf.constant(1f), tf.constant(2f))

works as expected.

Global variables would work like the Python variables (created and initialized in initScope, once). While I would like to do something to make them match across execution environments, that would require some form of environment-scoped memoization, which in turn requires environment-unique names for each variable (or ASM rewriting to add a call-site key); IMO that isn't worth it (see the sketch below for what it would look like). Memoization can be better handled by Layer APIs or similar. If anyone has ideas on how to do this, though, it would be great to have them.
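
For the record, the environment-scoped memoization I mean would look roughly like this (globalVariable, initScope, and the Variable class are all from this proposal, not existing API):

import org.tensorflow.ExecutionEnvironment
import org.tensorflow.op.Ops

// Hypothetical: globals keyed by (environment, user-supplied unique name), so re-tracing
// in the same environment hands back the same variable. Needing environment-unique names
// is exactly the part I don't think is worth it. Variable here is the proposed class.
private val globals = mutableMapOf<Pair<ExecutionEnvironment, String>, Variable>()

fun Ops.globalVariable(name: String, create: (Ops) -> Variable): Variable =
    globals.getOrPut(scope().env() to name) { create(initScope()) }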

I'm not sure how useful local variables would be, but given that it's not hard to do I think it's worth including.

I'd also like to not support variables without initial values (on the API side at least), since I can't think of a case where they would be useful and they're a bit of a pain to implement. It's doable if there's a need, though.

Saving and Loading

This is the hardest part. Saving graphs is easy enough: just add everything in the init scope as control dependencies of the restore op. Saving functions is much harder, since we need a way to move the eager init scope into the saved graph. I'm not sure how Python does it; we can probably do something similar. At a minimum we need the ability to copy an op to a new context, which we don't currently have (there are no TFE_OpGetAttr* methods).

Loading a graph is also easy: we just add the restore op to the init scope. Loading functions is trickier, since for functions in a bundle we need to figure out which variables each one uses, find the initializers for those variables, and then lift them from the graph into the function's init scope. It is possible, though, and until it's implemented we can point users towards the loaded-session-based methods.

For now I'd just not support saving and loading individual functions, and have users use the loaded session, which is what is done now (the "loaded" ConcreteFunctions use the same session and aren't callable with anything but tensors).

Requested Feedback

I'd like feedback on everything, really, but especially on whether:

  • we can add the new variables even w/o gradient support if properly marked and documented
  • to consider auto-adding control deps in Graph w/ an opt-out (like is done in functions)
  • local variables are worth adding
  • no-initial-value variables are worth adding
  • limiting loading to session (as is done now) is fine for now