@@ -5,6 +5,7 @@ package tarantool
5
5
import (
6
6
"bufio"
7
7
"bytes"
8
+ "context"
8
9
"errors"
9
10
"fmt"
10
11
"io"
@@ -143,16 +144,57 @@ type Connection struct {
143
144
144
145
var _ = Connector (& Connection {}) // Check compatibility with connector interface.
145
146
147
+ type futureList struct {
148
+ first * Future
149
+ last * * Future
150
+ }
151
+
152
+ func (list * futureList ) findFuture (reqid uint32 , fetch bool ) * Future {
153
+ root := & list .first
154
+ for {
155
+ fut := * root
156
+ if fut == nil {
157
+ return nil
158
+ }
159
+ if fut .requestId == reqid {
160
+ if fetch {
161
+ * root = fut .next
162
+ if fut .next == nil {
163
+ list .last = root
164
+ } else {
165
+ fut .next = nil
166
+ }
167
+ }
168
+ return fut
169
+ }
170
+ root = & fut .next
171
+ }
172
+ }
173
+
174
+ func (list * futureList ) addFuture (fut * Future ) {
175
+ * list .last = fut
176
+ list .last = & fut .next
177
+ }
178
+
179
+ func (list * futureList ) clear (err error , conn * Connection ) {
180
+ fut := list .first
181
+ list .first = nil
182
+ list .last = & list .first
183
+ for fut != nil {
184
+ fut .SetError (err )
185
+ conn .markDone (fut )
186
+ fut , fut .next = fut .next , nil
187
+ }
188
+ }
189
+
146
190
type connShard struct {
147
- rmut sync.Mutex
148
- requests [requestsMap ]struct {
149
- first * Future
150
- last * * Future
151
- }
152
- bufmut sync.Mutex
153
- buf smallWBuf
154
- enc * msgpack.Encoder
155
- _pad [16 ]uint64 //nolint: unused,structcheck
191
+ rmut sync.Mutex
192
+ requests [requestsMap ]futureList
193
+ requestsWithCtx [requestsMap ]futureList
194
+ bufmut sync.Mutex
195
+ buf smallWBuf
196
+ enc * msgpack.Encoder
197
+ _pad [16 ]uint64 //nolint: unused,structcheck
156
198
}
157
199
158
200
// Greeting is a message sent by Tarantool on connect.
@@ -286,6 +328,9 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
286
328
for j := range shard .requests {
287
329
shard .requests [j ].last = & shard .requests [j ].first
288
330
}
331
+ for j := range shard .requests {
332
+ shard .requestsWithCtx [j ].last = & shard .requestsWithCtx [j ].first
333
+ }
289
334
}
290
335
291
336
if opts .RateLimit > 0 {
@@ -387,6 +432,17 @@ func (conn *Connection) Handle() interface{} {
387
432
return conn .opts .Handle
388
433
}
389
434
435
+ func (conn * Connection ) cancelFuture (fut * Future , err error ) error {
436
+ if fut == nil {
437
+ return fmt .Errorf ("passed nil future" )
438
+ }
439
+ if fut = conn .fetchFuture (fut .requestId ); fut != nil {
440
+ fut .SetError (err )
441
+ conn .markDone (fut )
442
+ }
443
+ return nil
444
+ }
445
+
390
446
func (conn * Connection ) dial () (err error ) {
391
447
var connection net.Conn
392
448
network := "tcp"
@@ -582,14 +638,11 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
582
638
conn .shard [i ].buf .Reset ()
583
639
requests := & conn .shard [i ].requests
584
640
for pos := range requests {
585
- fut := requests [pos ].first
586
- requests [pos ].first = nil
587
- requests [pos ].last = & requests [pos ].first
588
- for fut != nil {
589
- fut .SetError (neterr )
590
- conn .markDone (fut )
591
- fut , fut .next = fut .next , nil
592
- }
641
+ requests [pos ].clear (neterr , conn )
642
+ }
643
+ requestsWithCtx := & conn .shard [i ].requestsWithCtx
644
+ for pos := range requestsWithCtx {
645
+ requestsWithCtx [pos ].clear (neterr , conn )
593
646
}
594
647
}
595
648
return
@@ -721,7 +774,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
721
774
}
722
775
}
723
776
724
- func (conn * Connection ) newFuture () (fut * Future ) {
777
+ func (conn * Connection ) newFuture (ctx context. Context ) (fut * Future ) {
725
778
fut = NewFuture ()
726
779
if conn .rlimit != nil && conn .opts .RLimitAction == RLimitDrop {
727
780
select {
@@ -761,11 +814,20 @@ func (conn *Connection) newFuture() (fut *Future) {
761
814
return
762
815
}
763
816
pos := (fut .requestId / conn .opts .Concurrency ) & (requestsMap - 1 )
764
- pair := & shard .requests [pos ]
765
- * pair .last = fut
766
- pair .last = & fut .next
767
- if conn .opts .Timeout > 0 {
768
- fut .timeout = time .Since (epoch ) + conn .opts .Timeout
817
+ if ctx != nil {
818
+ select {
819
+ case <- ctx .Done ():
820
+ fut .SetError (fmt .Errorf ("context is done" ))
821
+ shard .rmut .Unlock ()
822
+ return
823
+ default :
824
+ }
825
+ shard .requestsWithCtx [pos ].addFuture (fut )
826
+ } else {
827
+ shard .requests [pos ].addFuture (fut )
828
+ if conn .opts .Timeout > 0 {
829
+ fut .timeout = time .Since (epoch ) + conn .opts .Timeout
830
+ }
769
831
}
770
832
shard .rmut .Unlock ()
771
833
if conn .rlimit != nil && conn .opts .RLimitAction == RLimitWait {
@@ -785,12 +847,40 @@ func (conn *Connection) newFuture() (fut *Future) {
785
847
return
786
848
}
787
849
850
+ func (conn * Connection ) contextWatchdog (fut * Future , ctx context.Context ) {
851
+ select {
852
+ case <- fut .done :
853
+ default :
854
+ select {
855
+ case <- ctx .Done ():
856
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
857
+ default :
858
+ select {
859
+ case <- fut .done :
860
+ case <- ctx .Done ():
861
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
862
+ }
863
+ }
864
+ }
865
+ }
866
+
788
867
func (conn * Connection ) send (req Request ) * Future {
789
- fut := conn .newFuture ()
868
+ fut := conn .newFuture (req . Context () )
790
869
if fut .ready == nil {
791
870
return fut
792
871
}
872
+ if req .Context () != nil {
873
+ select {
874
+ case <- req .Context ().Done ():
875
+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
876
+ return fut
877
+ default :
878
+ }
879
+ }
793
880
conn .putFuture (fut , req )
881
+ if req .Context () != nil {
882
+ go conn .contextWatchdog (fut , req .Context ())
883
+ }
794
884
return fut
795
885
}
796
886
@@ -877,26 +967,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
877
967
func (conn * Connection ) getFutureImp (reqid uint32 , fetch bool ) * Future {
878
968
shard := & conn .shard [reqid & (conn .opts .Concurrency - 1 )]
879
969
pos := (reqid / conn .opts .Concurrency ) & (requestsMap - 1 )
880
- pair := & shard .requests [pos ]
881
- root := & pair .first
882
- for {
883
- fut := * root
884
- if fut == nil {
885
- return nil
886
- }
887
- if fut .requestId == reqid {
888
- if fetch {
889
- * root = fut .next
890
- if fut .next == nil {
891
- pair .last = root
892
- } else {
893
- fut .next = nil
894
- }
895
- }
896
- return fut
897
- }
898
- root = & fut .next
970
+ fut := shard .requests [pos ].findFuture (reqid , fetch )
971
+ if fut == nil {
972
+ fut = shard .requestsWithCtx [pos ].findFuture (reqid , fetch )
899
973
}
974
+ return fut
900
975
}
901
976
902
977
func (conn * Connection ) timeouts () {
@@ -1000,6 +1075,15 @@ func (conn *Connection) Do(req Request) *Future {
1000
1075
return fut
1001
1076
}
1002
1077
}
1078
+ if req .Context () != nil {
1079
+ select {
1080
+ case <- req .Context ().Done ():
1081
+ fut := NewFuture ()
1082
+ fut .SetError (fmt .Errorf ("context is done" ))
1083
+ return fut
1084
+ default :
1085
+ }
1086
+ }
1003
1087
return conn .send (req )
1004
1088
}
1005
1089
0 commit comments