
[AutoDiff] WIP - Adding support for differentiating wrt inout parameters. #25687


Closed
wants to merge 10 commits

Conversation

eaplatanios

@rxwei This is a very first step, but I thought it'd be good to start making progress and keep all the related discussion in one place.

@eaplatanios

@rxwei in the last commit I added a helper for handling inout types when computing the JVP and VJP types. I'll pause now until you take a look and tell me if this looks ok and what the next steps are.

@eaplatanios

@rxwei @dan-zheng I also added support for inout wrt parameters to AttributeChecker::visitDifferentiableAttr. It doesn't yet support propagating a property's @differentiable attribute to both the getter and setter, but I will add support for that once you confirm that my current approach is correct/reasonable.

@eaplatanios

I actually added support for propagating the @differentiable attribute to the getters and setters of properties. I'm going to bed now, but once you verify that this looks good, it'd be good to go through the next steps in a bit more detail because I'm not familiar at all with the SIL optimizer phases. I just noticed that we probably need to make changes here and here, but I'm not sure what those should be exactly.

@eaplatanios

I also ran the tests to see if everything is still good, and everything passes except for an unexpected error message. Adding @differentiable to property setters results in "cannot differentiate void function '_' with no inout wrt parameters", which is weird because it means that either setters do not have inout parameters, or automatically populating the wrt parameters currently ignores inout parameters.

eaplatanios changed the title from "[TF] WIP - Adding support for differentiating wrt inout parameters." to "[AutoDiff] WIP - Adding support for differentiating wrt inout parameters." on Jun 23, 2019
@eaplatanios

@rxwei I started looking a bit into the SIL optimization phases to try and figure out what to do next, but I had one question about function parameters: what's the difference between direct and indirect parameters? I assume I should use something like SILParameterInfo::isIndirectInOut() to detect the inout arguments?


rxwei commented Jun 24, 2019

Indirect parameters are parameters that need to be handled as buffers for various reasons (unknown size, inout, etc.). The abstraction difference section in the SIL reference has some explanation. inout parameters are always indirect.
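
For example, here is a minimal sketch of detecting them, assuming the usual SIL APIs (SILFunctionType::getParameters() and SILParameterInfo::isIndirectInOut()) and some lowered function type silFnType at hand:

  // Collect the indices of inout parameters of a lowered SIL function type.
  // Inout parameters are always lowered with an indirect convention, so
  // isIndirectInOut() identifies them.
  SmallVector<unsigned, 2> inoutParamIndices;
  auto params = silFnType->getParameters();
  for (unsigned i = 0, n = params.size(); i < n; ++i)
    if (params[i].isIndirectInOut())
      inoutParamIndices.push_back(i);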

@eaplatanios

@rxwei I've been going through the code and I'm wondering if we also need to change TypeBase::getAutoDiffAssociatedTangentSpace so that it can handle inout types. If so, I'm not sure how to change it, because I'm not sure whether inout T will be interpreted as conforming to Differentiable if T conforms. I believe we might have to change the Differentiable conformance derivation somewhere to deal with this, so that we can also make sure that the tangent space is inout.

@eaplatanios

The reason I'm asking is that I don't know whether this condition will be true for inout types.

@eaplatanios

So currently in this PR I have modified AnyFunctionType::getAutoDiffAssociatedFunctionType to map inout types like this:

  // Maps a parameter/result type to its tangent type. For inout types, the
  // tangent of the underlying object type is computed and re-wrapped in
  // InOutType.
  auto getTangentType = [lookupConformance](Type type) -> Type {
    if (!type->is<InOutType>()) {
      return type->getAutoDiffAssociatedTangentSpace(
          lookupConformance)->getType();
    }
    // inout case: unwrap, take the tangent of the object type, and re-wrap.
    Type base = type->getInOutObjectType();
    Type tangentBase = base->getAutoDiffAssociatedTangentSpace(
        lookupConformance)->getType();
    return InOutType::get(tangentBase);
  };

Should I instead move this inside TypeBase::getAutoDiffAssociatedTangentSpace so it's propagated wherever else that function is used?

@eaplatanios

Actually, I realize I cannot currently do that directly because VectorSpace does not support the notion of inout, so in order to move the mapping inside TypeBase::getAutoDiffAssociatedTangentSpace we would need to modify VectorSpace as well. @rxwei please let me know which option you think is preferable.


rxwei commented Jun 24, 2019

An inout type doesn't formally exist in the language; it's just a compiler data structure used to represent parameters that are inout. A tangent type by itself should never be inout. Instead of adding logic to TypeBase::getAutoDiffAssociatedTangentSpace or writing a function like this, we should handle this specifically in the places where we compute the tangent/cotangent functions' parameter types.
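
For example, a rough sketch of what that could look like when computing a derivative function's parameter list (getTangentParameter is a hypothetical helper, not code from this PR; AnyFunctionType::Param, ParameterTypeFlags, and getAutoDiffAssociatedTangentSpace are the existing pieces):

  // Hypothetical helper: compute the tangent parameter for an original
  // parameter. The tangent type itself is never wrapped in InOutType;
  // inout-ness is carried by the parameter flags instead.
  auto getTangentParameter =
      [lookupConformance](AnyFunctionType::Param origParam)
          -> AnyFunctionType::Param {
    Type tangentType = origParam.getPlainType()
        ->getAutoDiffAssociatedTangentSpace(lookupConformance)->getType();
    return AnyFunctionType::Param(tangentType, origParam.getLabel(),
                                  origParam.getParameterFlags());
  };

Reusing the original parameter flags keeps an inout parameter inout in the derivative's parameter list without ever creating an inout tangent type.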

@eaplatanios

I see. In that case, I guess my current implementation as part of AnyFunctionType::getAutoDiffAssociatedFunctionType should be ok. Would the next step be to change SILType::isDifferentiable and TypeResolver::isDifferentiableType so they can handle types marked as inout?
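
For instance, something along these lines (just a rough sketch; isDifferentiableFormalType is a made-up name standing in for those checks), looking through inout and checking the underlying object type, since the inout-ness lives on the parameter rather than on the type:

  // Sketch: when checking whether a formal type can be differentiated, look
  // through inout and check the underlying object type instead.
  static bool isDifferentiableFormalType(Type type,
                                         LookupConformanceFn lookupConformance) {
    if (auto *inOutType = type->getAs<InOutType>())
      type = inOutType->getObjectType();
    return (bool)type->getAutoDiffAssociatedTangentSpace(lookupConformance);
  }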

@eaplatanios

Actually I just pushed a change doing that because I feel like this is indeed necessary.

@eaplatanios

Also, I don't think I need to change DifferentiableActivityInfo::analyze after all, because the inout arguments will already be marked as useful if they're being differentiated with respect to, and so we don't need to add them to the set of output values at this stage. Is this reasoning correct?

@eaplatanios

I also started making changes to the VJP emitter, but I'm not at all sure if this is the right direction, so I'll pause until you can take a look at the current set of changes and summarize the appropriate next steps.

@eaplatanios

Also cc'ing @dan-zheng because I noticed that you have also worked a lot on this part of the codebase.


dan-zheng commented Feb 22, 2020

I accidentally deleted the tensorflow branch, which closed this PR. That was not intentional, sorry!
It would be nice to protect the tensorflow branch against deletion to prevent this from happening again.

dan-zheng reopened this Feb 22, 2020
@dan-zheng

Thank you @eaplatanios for starting this effort and discussion! Sorry for letting this PR rot.

#30013 just added inout argument differentiation support. The PR description has examples.
#29959 has typing rules for inout derivative functions.

@dan-zheng dan-zheng closed this Feb 23, 2020