Skip to content

Commit 80f815e

Browse files
committed
Add RemoteCallbacks.push_negotiation
1 parent 9efa238 commit 80f815e

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

pygit2/callbacks.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from pygit2._libgit2.ffi import GitProxyOptionsC
8282

8383
from ._pygit2 import CloneOptions, PushOptions
84-
from .remotes import TransferProgress
84+
from .remotes import PushUpdate, TransferProgress
8585
#
8686
# The payload is the way to pass information from the pygit2 API, through
8787
# libgit2, to the Python callbacks. And back.
@@ -198,6 +198,15 @@ def certificate_check(self, certificate: None, valid: bool, host: bytes) -> bool
198198

199199
raise Passthrough
200200

201+
def push_negotiation(self, updates: list['PushUpdate']) -> None:
202+
"""
203+
During a push, called once between the negotiation step and the upload.
204+
Provides information about what updates will be performed.
205+
206+
Override with your own function to check the pending updates
207+
and possibly reject them (by raising an exception).
208+
"""
209+
201210
def transfer_progress(self, stats: 'TransferProgress') -> None:
202211
"""
203212
During the download of new data, this will be regularly called with
@@ -427,6 +436,7 @@ def git_push_options(payload, opts=None):
427436
opts.callbacks.credentials = C._credentials_cb
428437
opts.callbacks.certificate_check = C._certificate_check_cb
429438
opts.callbacks.push_update_reference = C._push_update_reference_cb
439+
opts.callbacks.push_negotiation = C._push_negotiation_cb
430440
# Per libgit2 sources, push_transfer_progress may incur a performance hit.
431441
# So, set it only if the user has overridden the no-op stub.
432442
if (
@@ -559,6 +569,19 @@ def _credentials_cb(cred_out, url, username, allowed, data):
559569
return 0
560570

561571

572+
@libgit2_callback
573+
def _push_negotiation_cb(updates, num_updates, data):
574+
from .remotes import PushUpdate
575+
576+
push_negotiation = getattr(data, 'push_negotiation', None)
577+
if not push_negotiation:
578+
return 0
579+
580+
py_updates = [PushUpdate(updates[i]) for i in range(num_updates)]
581+
push_negotiation(py_updates)
582+
return 0
583+
584+
562585
@libgit2_callback
563586
def _push_update_reference_cb(ref, msg, data):
564587
push_update_reference = getattr(data, 'push_update_reference', None)

pygit2/decl/callbacks.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ extern "Python" int _push_update_reference_cb(
1616
const char *status,
1717
void *data);
1818

19+
extern "Python" int _push_negotiation_cb(
20+
const git_push_update **updates,
21+
size_t len,
22+
void *data);
23+
1924
extern "Python" int _remote_create_cb(
2025
git_remote **out,
2126
git_repository *repo,

pygit2/remotes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,34 @@ class LsRemotesDict(TypedDict):
5858
oid: Oid
5959

6060

61+
class PushUpdate:
62+
"""
63+
Represents an update which will be performed on the remote during push.
64+
"""
65+
66+
src_refname: str
67+
"""The source name of the reference"""
68+
69+
dst_refname: str
70+
"""The name of the reference to update on the server"""
71+
72+
src: Oid
73+
"""The current target of the reference"""
74+
75+
dst: Oid
76+
"""The new target for the reference"""
77+
78+
def __init__(self, c_struct: Any) -> None:
79+
src_refname = maybe_string(c_struct.src_refname)
80+
dst_refname = maybe_string(c_struct.dst_refname)
81+
assert src_refname is not None, 'libgit2 returned null src_refname'
82+
assert dst_refname is not None, 'libgit2 returned null dst_refname'
83+
self.src_refname = src_refname
84+
self.dst_refname = dst_refname
85+
self.src = Oid(raw=bytes(ffi.buffer(c_struct.src.id)[:]))
86+
self.dst = Oid(raw=bytes(ffi.buffer(c_struct.dst.id)[:]))
87+
88+
6189
class TransferProgress:
6290
"""Progress downloading and indexing data during a fetch."""
6391

test/test_remote.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
import pygit2
3333
from pygit2 import Remote, Repository
34-
from pygit2.remotes import TransferProgress
34+
from pygit2.remotes import PushUpdate, TransferProgress
3535

3636
from . import utils
3737

@@ -406,9 +406,12 @@ def push_transfer_progress(
406406
assert origin.branches['master'].target == new_tip_id
407407

408408

409+
@pytest.mark.parametrize('reject_from', ['push_transfer_progress', 'push_negotiation'])
409410
def test_push_interrupted_from_callbacks(
410-
origin: Repository, clone: Repository, remote: Remote
411+
origin: Repository, clone: Repository, remote: Remote, reject_from: str
411412
) -> None:
413+
reject_message = 'retreat! retreat!'
414+
412415
tip = clone[clone.head.target]
413416
clone.create_commit(
414417
'refs/heads/master',
@@ -420,10 +423,15 @@ def test_push_interrupted_from_callbacks(
420423
)
421424

422425
class MyCallbacks(pygit2.RemoteCallbacks):
426+
def push_negotiation(self, updates: list[PushUpdate]) -> None:
427+
if reject_from == 'push_negotiation':
428+
raise InterruptedError(reject_message)
429+
423430
def push_transfer_progress(
424431
self, objects_pushed: int, total_objects: int, bytes_pushed: int
425432
) -> None:
426-
raise InterruptedError('retreat! retreat!')
433+
if reject_from == 'push_transfer_progress':
434+
raise InterruptedError(reject_message)
427435

428436
assert origin.branches['master'].target == tip.id
429437

@@ -504,3 +512,37 @@ def test_push_threads(origin: Repository, clone: Repository, remote: Remote) ->
504512
callbacks = RemoteCallbacks()
505513
remote.push(['refs/heads/master'], callbacks, threads=1)
506514
assert callbacks.push_options.pb_parallelism == 1
515+
516+
517+
def test_push_negotiation(
518+
origin: Repository, clone: Repository, remote: Remote
519+
) -> None:
520+
old_tip = clone[clone.head.target]
521+
new_tip_id = clone.create_commit(
522+
'refs/heads/master',
523+
old_tip.author,
524+
old_tip.author,
525+
'empty commit',
526+
old_tip.tree.id,
527+
[old_tip.id],
528+
)
529+
530+
the_updates: list[PushUpdate] = []
531+
532+
class MyCallbacks(pygit2.RemoteCallbacks):
533+
def push_negotiation(self, updates: list[PushUpdate]) -> None:
534+
the_updates.extend(updates)
535+
536+
assert origin.branches['master'].target == old_tip.id
537+
assert 'new_branch' not in origin.branches
538+
539+
callbacks = MyCallbacks()
540+
remote.push(['refs/heads/master'], callbacks=callbacks)
541+
542+
assert len(the_updates) == 1
543+
assert the_updates[0].src_refname == 'refs/heads/master'
544+
assert the_updates[0].dst_refname == 'refs/heads/master'
545+
assert the_updates[0].src == old_tip.id
546+
assert the_updates[0].dst == new_tip_id
547+
548+
assert origin.branches['master'].target == new_tip_id

0 commit comments

Comments
 (0)