Initialization and Variables
As I'm working on a tf.function style API, I keep running into issues with initialization and variables, since they are all very interconnected. This is a proposal to streamline initialization and variable handling with an eye towards functions. I'm not going to get too far into the function API itself here, other than assuming it has builders like ConcreteFunction's and builds a ConcreteFunction depending on the input signature (like Python's). I'm also only working with the TF2-style resource variables (which are currently waiting on gradient support). The current draft is #179; I would change it to reflect what is described here.
I'd like to add the new variables at some point (once their design here is finalized) even w/o gradient support so that I can work on functions, which rely on them. Ideally this would be after 0.3.0 though, and either way we can mark them as "do not use" for now.
Python
Initialization
Note: auto-capture of eager variables by functions is allowed (i.e. passing them to ops).
Handled by init_scope:
- Graph: similar to what we do, a special scope which can then be called
- Eager session: just run them
- Function: an eager session owned by the function graph. Operands from that session can be used in the function w/o anything extra.
Variables
Handle creation and initialization is done in init_scope. Functions also have a fallback initialization using cond/if and isInitialized if the initial value can't be lifted to the init scope (i.e. if it depends on inputs). This has XLA compatibility issues.
This results in all function variables being global, and in the limitation that new variables can't be declared on subsequent runs.
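Conceptually, that cond/isInitialized fallback is something like this (a sketch; the variable API names are assumptions):

```kotlin
// Conceptual sketch (assumed names): if the variable's handle isn't
// initialized yet, run the initializer inside the function body, where it
// can depend on the function's inputs.
fun fallbackInit(tf: Ops, variable: Variable, initialValue: Operand<TFloat32>) {
    tf.cond(
        variable.isInitialized(),
        { tf.noOp() },                      // already initialized: nothing to do
        { variable.assign(initialValue) }   // first run: initialize, possibly from inputs
    )
}
```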
Functions
Python uses a graph subclass for special handling, and we will need to as well. It also inserts control dependencies between stateful ops, which we can and should do as well.
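Conceptually, that automatic ordering just chains each stateful op to the previously created one, something like (a sketch with assumed names):

```kotlin
// Sketch (assumed names): a function graph that chains stateful ops so
// they execute in creation order, mirroring the Python behavior above.
class FunctionGraph {
    private var lastStatefulOp: Op? = null

    // queried while building an op: stateful ops depend on the previous
    // stateful op; stateless ops get no extra dependencies
    fun controlDepsFor(stateful: Boolean): List<Op> =
        if (stateful) listOfNotNull(lastStatefulOp) else emptyList()

    fun recordStatefulOp(op: Op) {
        lastStatefulOp = op
    }
}
```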
My Proposal
Initialization
First, we should do initialization like Python does: have a special initScope() that tracks ops for graphs, runs them immediately for eager sessions, and is an eager session for functions.
Something like:
```kotlin
fun makeVariable(tf: Ops) {
    val variable = tf.initScope().Variable()
}
```
However, this has a rather large issue: functions will be re-executed multiple times (when the input signature changes), re-creating this variable multiple times. The easy answer is to do something like:
```kotlin
private var cache: Variable? = null

fun makeVariable(tf: Ops) {
    if (cache == null)
        cache = tf.initScope().Variable()
    val variable = cache!!
}
```
However, this is clunky, and it requires users to know whether or not they are in a function context (since this sort of caching isn't necessary most of the time in other contexts).
Something better and more environment agnostic is to expose an isInit() boolean that is always true for non-functions and only true on a function's first run. Then this becomes:
```kotlin
private var cache: Variable? = null

fun makeVariable(tf: Ops) {
    if (tf.isInit())
        cache = tf.initScope().Variable()
    val variable = cache!!
}
```
In eager mode, this works as expected over multiple calls, returning the sum each time.
In graph mode, it's a bit weird. Calling the function multiple times works fine. However, evaluating uses of the function re-uses the variable: the first time c is evaluated it is 3, the second time it is 6. Worse, getting add to work as expected requires adding control dependencies for the assigns. Notably, this isn't necessary in functions, because control dependencies are automatically added between stateful ops in the order they are created. I think it's worth considering doing the same for our Graphs (with an opt out).
In functions, it works like eager mode the first time and then errors on subsequent traces.
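For reference, getting the graph version of add to behave would mean wiring the control dependencies by hand, something like (a sketch; the controlDeps parameters are assumptions):

```kotlin
// each stateful op explicitly depends on the previous one, so the two
// assigns and the read execute in order (controlDeps is hypothetical)
val first = sum.assignAdd(a)
val second = sum.assignAdd(b, controlDeps = listOf(first))
val result = sum.read(controlDeps = listOf(first, second))
```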
I'd propose separating local and global variables. Local variables are not created in initScope and add control dependencies between init, assignments, and reads (and are not trainable). So
```kotlin
fun add(tf: Ops, a: Operand<TFloat32>, b: Operand<TFloat32>): Operand<TFloat32> {
    val sum = tf.LocalVariable(0)
    sum += a
    sum += b
    return sum.read()
}

val c = add(tf, 1, 2)
```
works as expected.
Global variables would work like the Python variables (created and init'd in initScope, once). While I would like to make them match across execution environments, that would require some form of environment-scoped memoization, which in turn requires environment-unique names for each variable (or ASM rewriting to add a call-site key); imo that isn't worth it. Memoization can be better handled by Layer APIs or similar. If anyone has ideas on how to do this, though, it would be great to have them.
I'm not sure how useful local variables would be, but given that it's not hard to do I think it's worth including.
I'd also like to not support variables without initial values (on the API side, at least), since I can't think of a case where it would be useful, and it's a bit of a pain to implement. It's doable if there's a need, though.
Saving and Loading
This is the hardest part. Saving graphs is easy enough: just add everything in the init scope as control dependencies to the restore op. Saving functions is much harder, since we need a way to move the eager init scope into the saved graph. I'm not sure how Python does it; we can probably use something similar. At a minimum we need the ability to copy an op to a new context, which we don't currently have (there are no TFE_OpGetAttr* methods).
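As a sketch of the graph-saving case (the init-scope accessor is an assumed name, and prefix, tensorNames, etc. stand in for the usual restore inputs):

```kotlin
// Sketch (assumed names): every op recorded in the init scope becomes a
// control dependency of the restore op, so running restore in the saved
// graph also re-runs initialization.
val initOps = graph.initScope().ops()
val restored = tf.withControlDependencies(initOps)
    .train.restore(prefix, tensorNames, shapeAndSlices, dtypes)
```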
Loading the graph is also easy: we just add the restore op to the init scope. Loading functions is trickier, since for functions in a bundle we need to figure out which variables each one uses, find the initializers for those variables, and then lift them from the graph to the function's init scope. It is possible though, and until it's implemented we can point users towards the loaded session based methods.
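In pseudocode, that lifting step might look like (all names hypothetical):

```kotlin
// Hypothetical sketch: for each variable a loaded function captures, find
// its initializer in the saved graph and copy it into the function's
// eager init scope (this needs the op-copying support mentioned above).
fun liftInitializers(function: ConcreteFunction, savedGraph: Graph) {
    for (variable in function.capturedVariables()) {
        val initOp = savedGraph.initializerFor(variable) ?: continue
        function.initScope().copyFrom(initOp)
    }
}
```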
For now I'd just not support saving and loading individual functions, and have users use the loaded session, which is what is done now (the "loaded" ConcreteFunctions use the same session and aren't callable with anything but tensors).
Requested Feedback
I'd like feedback on everything, really, but especially on whether:
- we can add the new variables even w/o gradient support if properly marked and documented
- to consider auto-adding control deps in Graph w/ an opt out (like is done in functions)
- local variables are worth adding
- no-initial-value variables are worth adding
- limiting loading to the session (as is done now) is fine for now