Skip to content

[AutoDiff][Sema] Properly type-check differentiability of parameters and results in differentiable function type declarations #41174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from

Conversation

philipturner
Copy link
Contributor

@philipturner philipturner commented Feb 3, 2022

Resolve ambiguities between lack of a return type and lack of a differentiable return type. If a function returns Void, it returns no values and therefore returns no differentiable values. This rule allows type checking to mirror how it processes conventional functions, despite Void being differentiable.

// Before this PR, the following declaration might crash the compiler.
typealias FuncType = @differentiable(reverse) (inout Float, Float) -> Void

// While the following function has the same signature and compiles just fine.
@differentiable(reverse)
func foo(x: inout Float, y: Float) {
  x += y
}

Resolves three issues, listed below. The SR-15818 regression test was modified to avoid triggering SR-15823, an unrelated compiler + runtime crasher.

@philipturner
Copy link
Contributor Author

philipturner commented Feb 3, 2022

CC: @AnthonyLatsis

@philipturner
Copy link
Contributor Author

#41128

Copy link
Collaborator

@AnthonyLatsis AnthonyLatsis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you planning to add some tests? I am no AutoDiff expert, so I can only speak for basic stuff.

lib/AST/Type.cpp Outdated
@@ -6021,8 +6021,7 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
newElts.push_back(elt.getWithType(eltSpace->getType()));
}
if (newElts.empty())
return cache(
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>()));
Copy link
Collaborator

@AnthonyLatsis AnthonyLatsis Feb 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to add a "non-empty tuple" assertion to TangentSpace::getTuple to prevent this from happening elsewhere, if this is indeed correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea. I’ll change that before this PR gets merged.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@philipturner
Copy link
Contributor Author

philipturner commented Feb 3, 2022

Are you planning to add some tests? I am no AutoDiff expert, so I can only speak for basic stuff.

You’re right! I should add tests! All the crashers (and the incorrect error diagnostic) I cited in SR-15808 should be tested.

Fixing the incorrect error diagnostic will probably require more changes to the C++ code, but within the same function as some other changes. I’m modifying how a differentiable closure is type-checked.

TangentSpace TangentSpace::getTuple(TupleType *tupleTy) {
if (tupleTy->getElements().size() == 0) {
llvm::report_fatal_error("Attempted to get tangent space of empty tuple.");
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: We mainly use assert to stand guard of internal invariants. llvm::report_fatal_error is meant for something the user is to blame for, like passing a flag and violating its contract.

@@ -366,6 +366,13 @@ GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
return derivativeGenSig;
}

TangentSpace TangentSpace::getTuple(TupleType *tupleTy) {
if (tupleTy->getElements().size() == 0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getNumElements is more efficient. Or you could do tupleTy->isVoid(), which does the check for you.

@philipturner
Copy link
Contributor Author

The inconsistent error diagnostic requires attention in a separate pull request. It should automatically consider non-differentiable stuff as like @noDerivative. But, when I bypass that check, it crashes in a later compilation stage. This restriction was never outlined in the Differentiable Programming Manifesto.

Also, there's a runtime crash here, where type checking should have prevented it from compiling. This runs successfully at runtime:

typealias MyType = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void

@differentiable(reverse)
func myFunc(_ x:  inout Float, _ q: @noDerivative Float, _ y: @noDerivative Int) -> Void {
  x
}
print(myFunc as MyType)

While this crashes, by adding an extra line that does nothing:

typealias MyType = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void

@differentiable(reverse)
func myFunc(_ x:  inout Float, _ q: @noDerivative Float, _ y: @noDerivative Int) -> Void {
  x
}
typealias MyType2 = @differentiable(reverse) (inout Float, @noDerivative Float, @noDerivative Int) -> Void
print(myFunc as MyType)

I got this crash by narrowing down the wierd error diagnostic, and I think they're related.

@AnthonyLatsis
Copy link
Collaborator

AnthonyLatsis commented Feb 3, 2022

This one-liner reduced from your example crashes during type-checking with my build:

import _Differentiation

@differentiable(reverse) func myFunc(_ x:  inout Float) -> Void {}

@philipturner
Copy link
Contributor Author

philipturner commented Feb 3, 2022

I don't see the crash happening. Did you build from my branch, or from a previous toolchain? The comment above was tested on the January 9, 2022 toolchain, not this branch.

The whole AutoDiff test suite passes on my branch as well.

@AnthonyLatsis
Copy link
Collaborator

Did you build from my branch, or from a previous toolchain?

No, this is a 2-day-old trunk.

@philipturner
Copy link
Contributor Author

Please try branch then. Or, manually insert my changes onto main (there aren't that many).

@CodaFi
Copy link
Contributor

CodaFi commented Feb 3, 2022

Void should never have been differentiable in the first place.

Is this true? It's odd, yet a reasonable thing to take the derivative of the following enum type

enum Top: Differentiable {
  case one
}

Which is isomorphic to Void. Are we missing a builtin conformance to Differentiable?

@philipturner
Copy link
Contributor Author

philipturner commented Feb 3, 2022

Top doesn't have a tangent vector. Try using it as the only input and output of a @differentiable function.

@CodaFi
Copy link
Contributor

CodaFi commented Feb 3, 2022

Top doesn't have a tangent vector.

Yes it does. In fact, it's (), an (arbitrary) representation of a zero-dimensional vector.

Try using it as the only input and output of a @differentiable function.

There is only one such function and it is differentiable (in the mathematical sense) because it's necessarily constant. Consider that the definition of differentiation involves quantifying over points not equal to whatever () is represented by in your equivalent manifold. There... aren't any of those.

@philipturner
Copy link
Contributor Author

Then how else do you suggest we resolve this bug?

@philipturner
Copy link
Contributor Author

Swift classifies @differentiable (Void) -> Void as an error.

@CodaFi
Copy link
Contributor

CodaFi commented Feb 3, 2022

This seems like a policy choice. Do you know what Julia or PyTorch do here?

@philipturner
Copy link
Contributor Author

I don't think that's relevant if the Swift compiler already has treated it as non-differentiable for years. All unit tests pass with the branch how it is, without modifications. Also, you could manually conform your enumeration to Differentiable. Doing that for Void is a compile failure because tuples can't conform to protocols. Furthermore, if Void was truly differentiable, every differentiable function that returns nothing and has an inout parameter would return two semantic results (a compile failure).

@CodaFi
Copy link
Contributor

CodaFi commented Feb 3, 2022

because tuples can't conform to protocols.

We have begun to build the machinery for this with BuiltinProtocolConformance. It's okay for us to declare this case as unsupported, but there needs to be a diagnostic for it - I already see one for differentiable enums.

Furthermore, if Void was truly differentiable, every differentiable function that returns nothing and has an inout parameter would return two semantic results (a compile failure).

That's a solid point, but autodiff already rejects functions with multiple result types like this. I'm not asking for that to change, rather to consider that the reasoning behind us rejecting differentiable functions in R^{0} has to be stronger than looking at the implementation and declaring it moot, or merely getting the tests to pass.

@philipturner
Copy link
Contributor Author

philipturner commented Feb 3, 2022

there needs to be a diagnostic for it - I already see one for differentiable enums.

Could you give me more context on how we might implement a special diagnostic for Void? For example, would we say something like:

Differentiation of R^{0} types coming in a future version of Swift

@philipturner philipturner changed the title [AutoDiff] Make Void no longer classified as differentiable [AutoDiff] Consistently classify Void as non-differentiable Feb 5, 2022
@philipturner

This comment was marked as off-topic.

@CodaFi
Copy link
Contributor

CodaFi commented Feb 11, 2022

please elaborate on the best way to address classifying Void as differentiable.

I think we can support it given the machinery in the compiler today. You'll need to extend a builtin conformance to these types. It's a lot of plumbing, but taking a look at how we do this for Sendable should give you a feel for how to do this for non-nominal types in general.

@philipturner
Copy link
Contributor Author

There's a way to solve these crashes without making Void un-differentiable.

@philipturner philipturner changed the title [AutoDiff] Consistently classify Void as non-differentiable [AutoDiff][Sema] Correctly process differentiable function type declarations with inout or Void Apr 21, 2022
@philipturner philipturner changed the title [AutoDiff][Sema] Correctly process differentiable function type declarations with inout or Void [AutoDiff][Sema] Properly type-check differentiability of parameters and results in differentiable function type declarations Apr 21, 2022
@philipturner philipturner marked this pull request as ready for review April 24, 2022 06:20
@asl asl added the AutoDiff label May 3, 2022
@philipturner philipturner requested a review from AnthonyLatsis May 4, 2022 21:39
@philipturner
Copy link
Contributor Author

#56412 might refer to #38781, which is why I am not marking that issue as fully resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
4 participants