-
Notifications
You must be signed in to change notification settings - Fork 267
fix(gh-2036): MyPy Errors in numpyro.distributions.transforms
Module
#2066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This looks great, thanks! There are some minor errors numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT" [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80 [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None") [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT") [return-value]
numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at" [union-attr] Do you need some help with these :) ? |
These errors are from numpyro/distributions/transforms.py#L77-L86, and I am not able to understand the significance of different conditions and the weak reference. I think you can look into this matter. I will take this one,
Thank you for offering help ❤️. |
ok! Sounds like a plan! I will try to look at it in the next days :) |
Hey @Qazalbash I gave it a try as in d57d9a6 . MyPy is happy now, maybe you can try it ? The only key change was |
@juanitorduz Thanks for the changes, Do I need to remove the plugin from |
It's depreciated so it's safe to remove |
…ng of inverse transforms Co-authored-by: Juan Orduz <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
Co-authored-by: Juan Orduz <[email protected]>
ok! I think the tests are failing because a new NNX release and changes in The other tests |
Here is a patch for the first errors #2067 |
Here's another patch #2069 😸 |
…lasses and add PyTree type alias
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @Qazalbash, I think we can get around issues by using NumLike just at some specific places. It's fine to use StrictArray at those arrays with dim >= 1. NonScalarArray is a good name for it I guess.
if inv is None: | ||
inv = _InverseTransform(self) | ||
self._inv = weakref.ref(inv) | ||
inv = cast(TransformT, _InverseTransform(self)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems unnecessary to me
inv = _InverseTransform(self) | ||
self._inv = weakref.ref(inv) | ||
inv = cast(TransformT, _InverseTransform(self)) | ||
self._inv = cast(TransformT, weakref.ref(inv)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need this cast, it seems incorrect to me?
@Qazalbash do you need any support here to bring this one to the finish line :) ? |
@juanitorduz thanks, but not quite right now. Hopefully, I will sit down tonight and tomorrow to complete it. |
@juanitorduz I thought it would be easy to fix, but to my surprise, these changes are generating more errors than before. Would you like to take over this issue? |
hey, sure! What about if you incrementally push the easiest suggestions (that still work) until you face a problem and then we take it from there? 🙏 |
@juanitorduz I think I have fixed all! See 1e2e670 |
Amazing @Qazalbash ! Thank you! |
This PR contains the resolution of mypy errors passed by #2032, in
numpyro.distributions.transforms
module.There are two cases in particular which I am unable to resolve. You can see them by running the mypy.
log_abs_det_jacobian
of many transforms have unused parameters. I have typed them as union ofUnusedParam
and some appropriate numpy/jax type.Many cases were unresolvable, like,
__eq__
method expectsbool
as return type, but&
operation between arrays return array of type bool, which conflicts with the return type, therefore I have added the tag to ignore them. You will find similar tags in the file.