-
Notifications
You must be signed in to change notification settings - Fork 65
ProjectTo{<:Tangent}
for tuples & Ref
#488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
17a941e
8b8314e
1d75133
ecd318f
15756a4
3e5cda1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -274,16 +274,64 @@ end | |
# Ref | ||
function ProjectTo(x::Ref) | ||
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)? | ||
if sub isa ProjectTo{<:AbstractZero} | ||
return ProjectTo{Tangent{typeof(x)}}(; x=sub) | ||
end | ||
|
||
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(only(backing(dx)))) | ||
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref) | ||
dy = project.x(dx[]) | ||
return project_type(project)(; x=dy) | ||
end | ||
# Since this works like a zero-array in broadcasting, it should also accept a number: | ||
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx)) | ||
|
||
# Tuple | ||
function ProjectTo(x::Tuple) | ||
elements = map(ProjectTo, x) | ||
if elements isa NTuple{<:Any, ProjectTo{<:AbstractZero}} | ||
return ProjectTo{NoTangent}() | ||
else | ||
return ProjectTo{Ref}(; type=typeof(x), x=sub) | ||
return ProjectTo{Tangent{typeof(x)}}(; elements=elements) | ||
end | ||
end | ||
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) | ||
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) | ||
# Since this works like a zero-array in broadcasting, it should also accept a number: | ||
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) | ||
|
||
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This means that projection on the output of this projector will disassemble the Tangent and re-process the Tuple inside. I'm not sure that's ideal. Maybe it's safe to pass on all Tangents without further investigation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should constant-fold out for most cases we care about. Do you want to check some with It is not safe to pass on all Tangents, because the tangent could be wrapping Complex Number/ Dense array that we need to fix. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I wonder is whether I can think of the "mathematical" steps involving arrays etc. as being separate from the "structural" steps involving Tangents. If the first project, and then the backward flow assembles and de-assembles a Tangent, can this Tangent have "crossed a boundary" such that it belongs to a different argument type and hence may need further projection? I mostly think it can't; it would have to get un-packaed and those pieces operated on. But I'm not very sure. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. But it gets fuzzy around the edges? |
||
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) | ||
len = length(project.elements) | ||
if length(dx) != len | ||
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))" | ||
throw(DimensionMismatch(str)) | ||
end | ||
Comment on lines
+305
to
+308
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will this not be caught by If we removed this check then this would basically be the general iterator fallback case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will, but the error is much less friendly... and might be a bug, JuliaLang/julia#42216 |
||
dy = map((f,x) -> f(x), project.elements, dx) | ||
mcabbott marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return project_type(project)(dy...) | ||
end | ||
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) | ||
for d in 1:ndims(dx) | ||
if size(dx, d) != get(length(project.elements), d, 1) | ||
throw(_projection_mismatch(axes(project.elements), size(dx))) | ||
end | ||
end | ||
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray | ||
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) | ||
return project_type(project)(dz...) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The above Tuple and AbstractArray cases are just optimizations of a general iterator one: function (project::ProjectTo{<:Tangent{<:Tuple}})(dxs) # iterator fallback
dzs = (f(dx) for (f, dx) in zip(project.elements, dxs))
return project_type(project)(dzs...)
end Should we have that as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could do. Do you have something in mind which might produce some weird type? If some NamedTuple leaks from Zygote, I think this will produce stranger error messages, since it may make a Tangent of the wrong length? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really have it in mind, more is that that is the general case we are handling. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I'm not entirely sure we need the array version at all. I was thinking about things like broadcasting, although that handles it explicitly... but map doesn't:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be down with seeing it removed til we know we need it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My vote is to keep this, although I can't think of another example besides |
||
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Number) = project(tuple(dx)) | ||
mcabbott marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
#= | ||
# NamedTuple | ||
function ProjectTo(x::NamedTuple) | ||
elements = map(ProjectTo, x) | ||
if all(p -> p isa ProjectTo{<:AbstractZero}, elements) | ||
return ProjectTo{NoTangent}() | ||
else | ||
return ProjectTo{Tangent{typeof(x)}}(; elements=elements) | ||
end | ||
end | ||
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::Union{Tuple,NamedTuple}) | ||
dy = map((f,x) -> f(x), project.elements, dx) | ||
return project_type(project)(dy...) | ||
end | ||
=# | ||
mcabbott marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
##### | ||
##### `LinearAlgebra` | ||
|
Uh oh!
There was an error while loading. Please reload this page.