12
12
from kafka .protocol import CODEC_NONE
13
13
14
14
import threading
15
- import multiprocessing as mp
16
15
try :
17
- from queue import Empty
16
+ from queue import Empty , Queue
18
17
except ImportError :
19
- from Queue import Empty
18
+ from Queue import Empty , Queue
20
19
21
20
22
21
class TestKafkaProducer (unittest .TestCase ):
@@ -56,33 +55,26 @@ def partitions(topic):
56
55
class TestKafkaProducerSendUpstream (unittest .TestCase ):
57
56
58
57
def setUp (self ):
59
-
60
- # create a multiprocessing Value to store call counter
61
- # (magicmock counters don't work with other processes)
62
- self .send_calls_count = mp .Value ('i' , 0 )
63
-
64
- def send_side_effect (* args , ** kwargs ):
65
- self .send_calls_count .value += 1
66
-
67
58
self .client = MagicMock ()
68
- self .client .send_produce_request .side_effect = send_side_effect
69
- self .queue = mp .Queue ()
59
+ self .queue = Queue ()
70
60
71
61
def _run_process (self , retries_limit = 3 , sleep_timeout = 1 ):
72
62
# run _send_upstream process with the queue
73
- self .process = mp .Process (
63
+ stop_event = threading .Event ()
64
+ self .thread = threading .Thread (
74
65
target = _send_upstream ,
75
66
args = (self .queue , self .client , CODEC_NONE ,
76
67
0.3 , # batch time (seconds)
77
68
3 , # batch length
78
69
Producer .ACK_AFTER_LOCAL_WRITE ,
79
70
Producer .DEFAULT_ACK_TIMEOUT ,
80
71
50 , # retry backoff (ms)
81
- retries_limit ))
82
- self .process .daemon = True
83
- self .process .start ()
72
+ retries_limit ,
73
+ stop_event ))
74
+ self .thread .daemon = True
75
+ self .thread .start ()
84
76
time .sleep (sleep_timeout )
85
- self . process . terminate ()
77
+ stop_event . set ()
86
78
87
79
def test_wo_retries (self ):
88
80
@@ -97,7 +89,8 @@ def test_wo_retries(self):
97
89
98
90
# there should be 4 non-void cals:
99
91
# 3 batches of 3 msgs each + 1 batch of 1 message
100
- self .assertEqual (self .send_calls_count .value , 4 )
92
+ self .assertEqual (self .client .send_produce_request .call_count , 4 )
93
+
101
94
102
95
def test_first_send_failed (self ):
103
96
@@ -106,11 +99,10 @@ def test_first_send_failed(self):
106
99
for i in range (10 ):
107
100
self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
108
101
109
- is_first_time = mp . Value ( 'b' , True )
102
+ self . client . is_first_time = True
110
103
def send_side_effect (reqs , * args , ** kwargs ):
111
- self .send_calls_count .value += 1
112
- if is_first_time .value :
113
- is_first_time .value = False
104
+ if self .client .is_first_time :
105
+ self .client .is_first_time = False
114
106
raise FailedPayloadsError (reqs )
115
107
116
108
self .client .send_produce_request .side_effect = send_side_effect
@@ -122,7 +114,7 @@ def send_side_effect(reqs, *args, **kwargs):
122
114
123
115
# there should be 5 non-void cals: 1st failed batch of 3 msgs
124
116
# + 3 batches of 3 msgs each + 1 batch of 1 msg = 1 + 3 + 1 = 5
125
- self .assertEqual (self .send_calls_count . value , 5 )
117
+ self .assertEqual (self .client . send_produce_request . call_count , 5 )
126
118
127
119
def test_with_limited_retries (self ):
128
120
@@ -132,7 +124,6 @@ def test_with_limited_retries(self):
132
124
self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
133
125
134
126
def send_side_effect (reqs , * args , ** kwargs ):
135
- self .send_calls_count .value += 1
136
127
raise FailedPayloadsError (reqs )
137
128
138
129
self .client .send_produce_request .side_effect = send_side_effect
@@ -145,8 +136,7 @@ def send_side_effect(reqs, *args, **kwargs):
145
136
# there should be 16 non-void cals:
146
137
# 3 initial batches of 3 msgs each + 1 initial batch of 1 msg +
147
138
# 3 retries of the batches above = 4 + 3 * 4 = 16, all failed
148
- self .assertEqual (self .send_calls_count .value , 16 )
149
-
139
+ self .assertEqual (self .client .send_produce_request .call_count , 16 )
150
140
151
141
def test_with_unlimited_retries (self ):
152
142
@@ -156,7 +146,6 @@ def test_with_unlimited_retries(self):
156
146
self .queue .put ((TopicAndPartition ("test" , i ), "msg %i" , "key %i" ))
157
147
158
148
def send_side_effect (reqs , * args , ** kwargs ):
159
- self .send_calls_count .value += 1
160
149
raise FailedPayloadsError (reqs )
161
150
162
151
self .client .send_produce_request .side_effect = send_side_effect
@@ -174,5 +163,5 @@ def send_side_effect(reqs, *args, **kwargs):
174
163
self .assertEqual (self .queue .empty (), True )
175
164
176
165
# 1s / 50ms of backoff = 20 times max
177
- self .assertTrue ( self . send_calls_count . value > 10 )
178
- self .assertTrue (self . send_calls_count . value <= 20 )
166
+ calls = self .client . send_produce_request . call_count
167
+ self .assertTrue (calls > 10 & calls <= 20 )
0 commit comments