diff --git a/kafka/conn.py b/kafka/conn.py index 4fdeb17c7..90dee4c72 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -59,6 +59,9 @@ def __init__(self, host, port, timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS): self.reinit() + def __getnewargs__(self): + return (self.host, self.port, self.timeout) + def __repr__(self): return "" % (self.host, self.port) @@ -135,6 +138,7 @@ def copy(self): """ c = copy.deepcopy(self) c._sock = None + c._dirty = True return c def close(self): diff --git a/kafka/producer.py b/kafka/producer.py index 12a293401..37409121b 100644 --- a/kafka/producer.py +++ b/kafka/producer.py @@ -3,10 +3,10 @@ import logging import time -from Queue import Empty +from Queue import Empty, Queue from collections import defaultdict from itertools import cycle -from multiprocessing import Queue, Process +from threading import Thread from kafka.common import ProduceRequest, TopicAndPartition from kafka.partitioner import HashedPartitioner @@ -26,13 +26,8 @@ def _send_upstream(queue, client, batch_time, batch_size, Listen on the queue for a specified number of messages or till a specified timeout and send them upstream to the brokers in one request - - NOTE: Ideally, this should have been a method inside the Producer - class. However, multiprocessing module has issues in windows. The - functionality breaks unless this function is kept outside of a class """ stop = False - client.reinit() while not stop: timeout = batch_time @@ -120,17 +115,17 @@ def __init__(self, client, async=False, if self.async: self.queue = Queue() # Messages are sent through this queue - self.proc = Process(target=_send_upstream, - args=(self.queue, - self.client.copy(), - batch_send_every_t, - batch_send_every_n, - self.req_acks, - self.ack_timeout)) - - # Process will die if main thread exits - self.proc.daemon = True - self.proc.start() + self.thread = Thread(target=_send_upstream, + args=(self.queue, + self.client.copy(), + batch_send_every_t, + batch_send_every_n, + self.req_acks, + self.ack_timeout)) + + # Thread will die if main thread exits + self.thread.daemon = True + self.thread.start() def send_messages(self, topic, partition, *msg): """ @@ -159,10 +154,7 @@ def stop(self, timeout=1): """ if self.async: self.queue.put((STOP_ASYNC_PRODUCER, None)) - self.proc.join(timeout) - - if self.proc.is_alive(): - self.proc.terminate() + self.thread.join(timeout) class SimpleProducer(Producer): diff --git a/test/test_unit.py b/test/test_unit.py index 8c0dd004f..1ed69358b 100644 --- a/test/test_unit.py +++ b/test/test_unit.py @@ -6,6 +6,7 @@ from mock import MagicMock, patch from kafka import KafkaClient +from kafka.conn import KafkaConnection from kafka.common import ( ProduceRequest, FetchRequest, Message, ChecksumError, ConsumerFetchSizeTooSmall, ProduceResponse, FetchResponse, @@ -670,5 +671,49 @@ def test_send_produce_request_raises_when_noleader(self, protocol, conn): LeaderUnavailableError, client.send_produce_request, requests) +class TestKafkaConnection(unittest.TestCase): + @patch('socket.socket') + def test_copy(self, socket): + """KafkaConnection copies work as expected""" + + conn = KafkaConnection('kafka', 9092) + self.assertEqual(socket.call_count, 1) + + copy = conn.copy() + self.assertEqual(socket.call_count, 1) + self.assertEqual(copy.host, 'kafka') + self.assertEqual(copy.port, 9092) + self.assertEqual(copy._sock, None) + self.assertEqual(copy._dirty, True) + + copy.reinit() + self.assertEqual(socket.call_count, 2) + self.assertNotEqual(copy._sock, None) + self.assertNotEqual(copy._dirty, True) + + @patch('socket.socket') + def test_copy_thread(self, socket): + """KafkaConnection copies work in other threads""" + + err = [] + copy = KafkaConnection('kafka', 9092).copy() + + from threading import Thread + def thread_func(err, copy): + try: + self.assertEqual(copy.host, 'kafka') + self.assertEqual(copy.port, 9092) + self.assertNotEqual(copy._sock, None) + except Exception, e: + err.append(e) + else: + err.append(None) + thread = Thread(target=thread_func, args=(err, copy)) + thread.start() + thread.join() + + self.assertEqual(err, [None]) + self.assertEqual(socket.call_count, 2) + if __name__ == '__main__': unittest.main()