Skip to content

Commit 1c1cccd

Browse files
committed
rafthttp: stop etcd if it is found removed when stream dial
The original process is stopping etcd only when pipeline message finds itself has been removed. After this PR, stream dial has this functionality too. It helps fast etcd stop, which doesn't need to wait for stream break to fall back to pipeline, and wait for election timeout to send out message to detect self removal.
1 parent be6f49b commit 1c1cccd

File tree

9 files changed

+64
-20
lines changed

9 files changed

+64
-20
lines changed

etcdserver/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
348348
return s.r.Step(ctx, m)
349349
}
350350

351+
func (s *EtcdServer) IsIDRemoved(id uint64) bool { return s.Cluster.IsIDRemoved(types.ID(id)) }
352+
351353
func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
352354

353355
func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {

rafthttp/functional_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
134134
}
135135

136136
type fakeRaft struct {
137-
recvc chan<- raftpb.Message
138-
err error
137+
recvc chan<- raftpb.Message
138+
err error
139+
removedID uint64
139140
}
140141

141142
func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
@@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
146147
return p.err
147148
}
148149

150+
func (p *fakeRaft) IsIDRemoved(id uint64) bool { return id == p.removedID }
151+
149152
func (p *fakeRaft) ReportUnreachable(id uint64) {}
150153

151154
func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}

rafthttp/http.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ type peerGetter interface {
4646
Get(id types.ID) Peer
4747
}
4848

49-
func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
49+
func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
5050
return &streamHandler{
5151
peerGetter: peerGetter,
52+
r: r,
5253
id: id,
5354
cid: cid,
5455
}
@@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112113

113114
type streamHandler struct {
114115
peerGetter peerGetter
116+
r Raft
115117
id types.ID
116118
cid types.ID
117119
}
@@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
145147
http.Error(w, "invalid from", http.StatusNotFound)
146148
return
147149
}
150+
if h.r.IsIDRemoved(uint64(from)) {
151+
log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
152+
http.Error(w, "removed member", http.StatusGone)
153+
return
154+
}
148155
p := h.peerGetter.Get(from)
149156
if p == nil {
150157
log.Printf("rafthttp: fail to find sender %s", from)

rafthttp/http_test.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package rafthttp
1717
import (
1818
"bytes"
1919
"errors"
20+
"fmt"
2021
"io"
2122
"net/http"
2223
"net/http/httptest"
@@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
185186

186187
peer := newFakePeer()
187188
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
188-
h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
189+
h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
189190

190191
rw := httptest.NewRecorder()
191192
go h.ServeHTTP(rw, req)
@@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
207208
}
208209

209210
func TestServeRaftStreamPrefixBad(t *testing.T) {
211+
removedID := uint64(5)
210212
tests := []struct {
211213
method string
212214
path string
@@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
263265
"1",
264266
http.StatusNotFound,
265267
},
268+
// removed peer
269+
{
270+
"GET",
271+
RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
272+
"1",
273+
"1",
274+
http.StatusGone,
275+
},
266276
// wrong cluster ID
267277
{
268278
"GET",
@@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
289299
req.Header.Set("X-Raft-To", tt.remote)
290300
rw := httptest.NewRecorder()
291301
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
292-
h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
302+
r := &fakeRaft{removedID: removedID}
303+
h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
293304
h.ServeHTTP(rw, req)
294305

295306
if rw.Code != tt.wcode {

rafthttp/peer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
149149

150150
go func() {
151151
var paused bool
152-
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
153-
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
152+
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
153+
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
154154
for {
155155
select {
156156
case m := <-p.sendc:

rafthttp/stream.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ type streamReader struct {
226226
cid types.ID
227227
recvc chan<- raftpb.Message
228228
propc chan<- raftpb.Message
229+
errorc chan<- error
229230

230231
mu sync.Mutex
231232
msgAppTerm uint64
@@ -235,7 +236,7 @@ type streamReader struct {
235236
done chan struct{}
236237
}
237238

238-
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
239+
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
239240
r := &streamReader{
240241
tr: tr,
241242
picker: picker,
@@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
245246
cid: cid,
246247
recvc: recvc,
247248
propc: propc,
249+
errorc: errorc,
248250
stopc: make(chan struct{}),
249251
done: make(chan struct{}),
250252
}
@@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
367369
cr.picker.unreachable(u)
368370
return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
369371
}
370-
if resp.StatusCode != http.StatusOK {
372+
switch resp.StatusCode {
373+
case http.StatusGone:
374+
resp.Body.Close()
375+
err := fmt.Errorf("the member has been permanently removed from the cluster")
376+
select {
377+
case cr.errorc <- err:
378+
default:
379+
}
380+
return nil, err
381+
case http.StatusOK:
382+
return resp.Body, nil
383+
default:
371384
resp.Body.Close()
372385
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
373386
}
374-
return resp.Body, nil
375387
}
376388

377389
func (cr *streamReader) cancelRequest() {

rafthttp/stream_test.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
119119
// HTTP response received.
120120
func TestStreamReaderDialResult(t *testing.T) {
121121
tests := []struct {
122-
code int
123-
err error
124-
wok bool
122+
code int
123+
err error
124+
wok bool
125+
whalt bool
125126
}{
126-
{0, errors.New("blah"), false},
127-
{http.StatusOK, nil, true},
128-
{http.StatusMethodNotAllowed, nil, false},
129-
{http.StatusNotFound, nil, false},
130-
{http.StatusPreconditionFailed, nil, false},
127+
{0, errors.New("blah"), false, false},
128+
{http.StatusOK, nil, true, false},
129+
{http.StatusMethodNotAllowed, nil, false, false},
130+
{http.StatusNotFound, nil, false, false},
131+
{http.StatusPreconditionFailed, nil, false, false},
132+
{http.StatusGone, nil, false, true},
131133
}
132134
for i, tt := range tests {
133135
tr := newRespRoundTripper(tt.code, tt.err)
@@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
138140
from: types.ID(1),
139141
to: types.ID(2),
140142
cid: types.ID(1),
143+
errorc: make(chan error, 1),
141144
}
142145

143146
_, err := sr.dial()
144147
if ok := err == nil; ok != tt.wok {
145148
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
146149
}
150+
if halt := len(sr.errorc) > 0; halt != tt.whalt {
151+
t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
152+
}
147153
}
148154
}
149155

@@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
203209
h.sw = sw
204210

205211
picker := mustNewURLPicker(t, []string{srv.URL})
206-
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
212+
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
207213
defer sr.stop()
208214
if tt.t == streamTypeMsgApp {
209215
sr.updateMsgAppTerm(tt.term)

rafthttp/transport.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
type Raft interface {
3030
Process(ctx context.Context, m raftpb.Message) error
31+
IsIDRemoved(id uint64) bool
3132
ReportUnreachable(id uint64)
3233
ReportSnapshot(id uint64, status raft.SnapshotStatus)
3334
}
@@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
9899

99100
func (t *transport) Handler() http.Handler {
100101
pipelineHandler := NewHandler(t.raft, t.clusterID)
101-
streamHandler := newStreamHandler(t, t.id, t.clusterID)
102+
streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
102103
mux := http.NewServeMux()
103104
mux.Handle(RaftPrefix, pipelineHandler)
104105
mux.Handle(RaftStreamPrefix+"/", streamHandler)

rafthttp/transport_bench_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
8888
return nil
8989
}
9090

91+
func (r *countRaft) IsIDRemoved(id uint64) bool { return false }
92+
9193
func (r *countRaft) ReportUnreachable(id uint64) {}
9294

9395
func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}

0 commit comments

Comments
 (0)