-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff][Sema] Properly type-check differentiability of parameters and results in differentiable function type declarations #41174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
CC: @AnthonyLatsis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you planning to add some tests? I am no AutoDiff expert, so I can only speak for basic stuff.
lib/AST/Type.cpp
Outdated
@@ -6021,8 +6021,7 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) { | |||
newElts.push_back(elt.getWithType(eltSpace->getType())); | |||
} | |||
if (newElts.empty()) | |||
return cache( | |||
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may want to add a "non-empty tuple" assertion to TangentSpace::getTuple
to prevent this from happening elsewhere, if this is indeed correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a good idea. I’ll change that before this PR gets merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
You’re right! I should add tests! All the crashers (and the incorrect error diagnostic) I cited in SR-15808 should be tested. Fixing the incorrect error diagnostic will probably require more changes to the C++ code, but within the same function as some other changes. I’m modifying how a differentiable closure is type-checked. |
lib/AST/AutoDiff.cpp
Outdated
TangentSpace TangentSpace::getTuple(TupleType *tupleTy) { | ||
if (tupleTy->getElements().size() == 0) { | ||
llvm::report_fatal_error("Attempted to get tangent space of empty tuple."); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: We mainly use assert
to stand guard of internal invariants. llvm::report_fatal_error
is meant for something the user is to blame for, like passing a flag and violating its contract.
lib/AST/AutoDiff.cpp
Outdated
@@ -366,6 +366,13 @@ GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature( | |||
return derivativeGenSig; | |||
} | |||
|
|||
TangentSpace TangentSpace::getTuple(TupleType *tupleTy) { | |||
if (tupleTy->getElements().size() == 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getNumElements
is more efficient. Or you could do tupleTy->isVoid()
, which does the check for you.
The inconsistent error diagnostic requires attention in a separate pull request. It should automatically consider non-differentiable stuff as like Also, there's a runtime crash here, where type checking should have prevented it from compiling. This runs successfully at runtime: typealias MyType = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void
@differentiable(reverse)
func myFunc(_ x: inout Float, _ q: @noDerivative Float, _ y: @noDerivative Int) -> Void {
x
}
print(myFunc as MyType) While this crashes, by adding an extra line that does nothing: typealias MyType = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void
@differentiable(reverse)
func myFunc(_ x: inout Float, _ q: @noDerivative Float, _ y: @noDerivative Int) -> Void {
x
}
typealias MyType2 = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void
print(myFunc as MyType) I got this crash by narrowing down the wierd error diagnostic, and I think they're related. |
This one-liner reduced from your example crashes during type-checking with my build: import _Differentiation
@differentiable(reverse) func myFunc(_ x: inout Float) -> Void {} |
I don't see the crash happening. Did you build from my branch, or from a previous toolchain? The comment above was tested on the January 9, 2022 toolchain, not this branch. The whole AutoDiff test suite passes on my branch as well. |
No, this is a 2-day-old trunk. |
Please try branch then. Or, manually insert my changes onto main (there aren't that many). |
Is this true? It's odd, yet a reasonable thing to take the derivative of the following enum type
Which is isomorphic to Void. Are we missing a builtin conformance to Differentiable? |
|
Yes it does. In fact, it's
There is only one such function and it is differentiable (in the mathematical sense) because it's necessarily constant. Consider that the definition of differentiation involves quantifying over points not equal to whatever |
Then how else do you suggest we resolve this bug? |
Swift classifies |
This seems like a policy choice. Do you know what Julia or PyTorch do here? |
I don't think that's relevant if the Swift compiler already has treated it as non-differentiable for years. All unit tests pass with the branch how it is, without modifications. Also, you could manually conform your enumeration to Differentiable. Doing that for Void is a compile failure because tuples can't conform to protocols. Furthermore, if Void was truly differentiable, every differentiable function that returns nothing and has an inout parameter would return two semantic results (a compile failure). |
We have begun to build the machinery for this with BuiltinProtocolConformance. It's okay for us to declare this case as unsupported, but there needs to be a diagnostic for it - I already see one for differentiable enums.
That's a solid point, but autodiff already rejects functions with multiple result types like this. I'm not asking for that to change, rather to consider that the reasoning behind us rejecting differentiable functions in |
Could you give me more context on how we might implement a special diagnostic for Void? For example, would we say something like:
|
Void
no longer classified as differentiableVoid
as non-differentiable
This comment was marked as off-topic.
This comment was marked as off-topic.
I think we can support it given the machinery in the compiler today. You'll need to extend a builtin conformance to these types. It's a lot of plumbing, but taking a look at how we do this for Sendable should give you a feel for how to do this for non-nominal types in general. |
There's a way to solve these crashes without making |
6d66e8a
to
d2f5188
Compare
Void
as non-differentiableinout
or Void
inout
or Void
* a commit * another commit * Update lib/Sema/TypeCheckType.cpp * Update lib/Sema/TypeCheckType.cpp
78a5b82
to
4b8312d
Compare
Resolve ambiguities between lack of a return type and lack of a differentiable return type. If a function returns
Void
, it returns no values and therefore returns no differentiable values. This rule allows type checking to mirror how it processes conventional functions, despiteVoid
being differentiable.Resolves three issues, listed below. The SR-15818 regression test was modified to avoid triggering SR-15823, an unrelated compiler + runtime crasher.