-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] OT barycenters for generic transport costs #715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #715 +/- ##
==========================================
+ Coverage 97.10% 97.14% +0.04%
==========================================
Files 100 100
Lines 20453 20849 +396
==========================================
+ Hits 19861 20254 +393
- Misses 592 595 +3 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @eloitanguy this is very nice work as usual. I have a couple comments below
ot/lp/_barycenter_solvers.py
Outdated
List of K arrays of measure weights, each of shape (m_k). | ||
X_init : array-like | ||
Array of shape (n, d) representing initial barycenter points. | ||
cost_list : list of callable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you accept a callable (if same cost for eveyone) and a list?
After discussion with @rflamary we decided to expand the PR with additional features, namely two barycenter solvers:
|
After some updates to the paper behind these algorithms (the paper update is soon to come), I implemented another barycenter solver which corresponds exactly to the method studied theoretically in the paper (iterates of G). I updated this PR with the additional algorithm ( On my end, this contribution is ready for review :D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for this very nice PR @eloitanguy ! :)
Follows some comments on the generic barycentre solver for now. I'll revise the GMM parts next.
Barycentre method: 'L2_barycentric_proj' (default) for Euclidean | ||
barycentric projection, or 'true_fixed_point' for iterates using the | ||
North West Corner multi-marginal gluing method. | ||
stopThr : float, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Specify the criterion
ground_bary_numItermax : int, optional | ||
Maximum number of iterations for the ground barycenter solver (if auto | ||
is used). | ||
ground_bary_stopThr : float, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
specify the criterion.
emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) | ||
for k in range(K) | ||
] | ||
Y_perm = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to remove for consistency
if method == "L2_barycentric_proj": | ||
a_next = a # barycentre weights are fixed | ||
for k in range(K): # L2 barycentric projection of pi_k | ||
Y_perm.append((1 / a[:, None]) * pi_list[k] @ measure_locations[k]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Y_perm = [... for k in range(K)]
raise ImportError("PyTorch is required to use ground_bary=None") | ||
|
||
X_list = [X_init] if log else [] # store the iterations | ||
a_list = [nx.copy(a)] if log and method == "true_fixed_point" else [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy needed ?
m_k}`. If cost_list is a single callable, the same cost is used K times. | ||
ground_bary : callable or None, optional | ||
Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays | ||
of shape (n\times d_K), computing the ground barycenters (broadcasted |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
math ?
(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. A | ||
barycenter is defined ([76]) as a minimiser of the energy :math:`V(\mu) = \sum_k | ||
\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter | ||
measure, the measures :math:`\nu_k` are the target measures and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rendering issue ?
|
||
torch.manual_seed(42) | ||
|
||
n = 136 # number of points of the of the barycentre |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove of the
n = 136 # number of points of the of the barycentre | ||
d = 2 # dimensions of the original measure | ||
K = 4 # number of measures to barycentre | ||
m = 50 # number of points of the measures |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we take different number of points ?
# -*- coding: utf-8 -*- | ||
""" | ||
===================================== | ||
OT Barycenter with Generic Costs Demo |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice illustrations :)
Types of changes
free_support
that accepts any cost function (implements this paper)ot.gmm
for fast computation of GMM barycentersREADME.md
For the (theoretical) fixed-point method, use
method='true_fixed_point'
inot.lp.free_support_barycenter_generic_costs
and for the barycentric heuristic, usemethod='L2_barycentric_proj'
. The latter is the default, given the computational advantages and the desirable property of keeping a fixed support size.Motivation and context / Related issue
How has this been tested (if it applies)
test/test_ot.py
andtest/test_gmm.py
PR checklist