diff --git a/google/cloud/logging_v2/client.py b/google/cloud/logging_v2/client.py
index 1b5beeb24..f52845ee5 100644
--- a/google/cloud/logging_v2/client.py
+++ b/google/cloud/logging_v2/client.py
@@ -149,6 +149,8 @@ def __init__(
         else:
             self._use_grpc = _use_grpc
 
+        self._handlers = set()
+
     @property
     def logging_api(self):
         """Helper for logging-related API calls.
@@ -411,4 +413,17 @@ def setup_logging(
             dict: keyword args passed to handler constructor
         """
         handler = self.get_default_handler(**kw)
+        self._handlers.add(handler)
         setup_logging(handler, log_level=log_level, excluded_loggers=excluded_loggers)
+
+    def flush_handlers(self):
+        """Flushes all Python log handlers associated with this Client."""
+
+        for handler in self._handlers:
+            handler.flush()
+
+    def close(self):
+        """Closes the Client and all handlers associated with this Client."""
+        super(Client, self).close()
+        for handler in self._handlers:
+            handler.close()
diff --git a/google/cloud/logging_v2/handlers/handlers.py b/google/cloud/logging_v2/handlers/handlers.py
index e71f673f7..364246d58 100644
--- a/google/cloud/logging_v2/handlers/handlers.py
+++ b/google/cloud/logging_v2/handlers/handlers.py
@@ -188,7 +188,10 @@ def __init__(
             resource = detect_resource(client.project)
         self.name = name
         self.client = client
+        client._handlers.add(self)
         self.transport = transport(client, name, resource=resource)
+        self._transport_open = True
+        self._transport_cls = transport
         self.project_id = client.project
         self.resource = resource
         self.labels = labels
@@ -213,6 +216,12 @@ def emit(self, record):
         labels = {**add_resource_labels(resource, record), **(labels or {})} or None
 
         # send off request
+        if not self._transport_open:
+            self.transport = self._transport_cls(
+                self.client, self.name, resource=self.resource
+            )
+            self._transport_open = True
+
         self.transport.send(
             record,
             message,
@@ -225,6 +234,21 @@ def emit(self, record):
             source_location=record._source_location,
         )
 
+    def flush(self):
+        """Forces the Transport object to submit any pending log records.
+
+        For SyncTransport, this is a no-op.
+        """
+        super(CloudLoggingHandler, self).flush()
+        if self._transport_open:
+            self.transport.flush()
+
+    def close(self):
+        """Closes the log handler and cleans up all Transport objects used."""
+        self.transport.close()
+        self.transport = None
+        self._transport_open = False
+
 
 def _format_and_parse_message(record, formatter_handler):
     """
diff --git a/google/cloud/logging_v2/handlers/transports/background_thread.py b/google/cloud/logging_v2/handlers/transports/background_thread.py
index 7cf2799f5..021112fdb 100644
--- a/google/cloud/logging_v2/handlers/transports/background_thread.py
+++ b/google/cloud/logging_v2/handlers/transports/background_thread.py
@@ -38,6 +38,13 @@
 _WORKER_TERMINATOR = object()
 _LOGGER = logging.getLogger(__name__)
 
+_CLOSE_THREAD_SHUTDOWN_ERROR_MSG = (
+    "CloudLoggingHandler shutting down, cannot send log entries to Cloud Logging due to "
+    "inconsistent threading behavior at shutdown. To avoid this issue, flush the logging handler "
+    "manually or switch to StructuredLogHandler. You can also close the CloudLoggingHandler manually "
+    "via handler.close or client.close."
+)
+
 
 def _get_many(queue_, *, max_items=None, max_latency=0):
     """Get multiple items from a Queue.
@@ -140,9 +147,11 @@ def _thread_main(self):
                 else:
                     batch.log(**item)
 
-            self._safely_commit_batch(batch)
+            # We cannot commit logs upstream if the main thread is shutting down
+            if threading.main_thread().is_alive():
+                self._safely_commit_batch(batch)
 
             for _ in items:
                 self._queue.task_done()
 
         _LOGGER.debug("Background thread exited gracefully.")
@@ -162,7 +171,7 @@ def start(self):
             )
             self._thread.daemon = True
             self._thread.start()
-            atexit.register(self._main_thread_terminated)
+            atexit.register(self._handle_exit)
 
     def stop(self, *, grace_period=None):
         """Signals the background thread to stop.
@@ -202,26 +211,26 @@ def stop(self, *, grace_period=None):
 
         return success
 
-    def _main_thread_terminated(self):
-        """Callback that attempts to send pending logs before termination."""
+    def _close(self, close_msg):
+        """Callback that attempts to send pending logs before termination if the main thread is alive."""
         if not self.is_alive:
             return
 
         if not self._queue.empty():
-            print(
-                "Program shutting down, attempting to send %d queued log "
-                "entries to Cloud Logging..." % (self._queue.qsize(),),
-                file=sys.stderr,
-            )
+            print(close_msg, file=sys.stderr)
 
-        if self.stop(grace_period=self._grace_period):
+        if threading.main_thread().is_alive() and self.stop(
+            grace_period=self._grace_period
+        ):
             print("Sent all pending logs.", file=sys.stderr)
-        else:
+        elif not self._queue.empty():
             print(
                 "Failed to send %d pending logs." % (self._queue.qsize(),),
                 file=sys.stderr,
             )
 
+        self._thread = None
+
     def enqueue(self, record, message, **kwargs):
         """Queues a log entry to be written by the background thread.
 
@@ -251,6 +260,26 @@ def flush(self):
         """Submit any pending log records."""
         self._queue.join()
 
+    def close(self):
+        """Signals the worker thread to stop, then closes the transport thread.
+
+        This call will attempt to send pending logs before termination, and
+        should be followed up by disowning the transport object.
+        """
+        atexit.unregister(self._handle_exit)
+        self._close(
+            "Background thread shutting down, attempting to send %d queued log "
+            "entries to Cloud Logging..." % (self._queue.qsize(),)
+        )
+
+    def _handle_exit(self):
+        """Handle system exit.
+
+        Since we cannot send pending logs during system shutdown due to thread errors,
+        log an error message to stderr to notify the user.
+        """
+        self._close(_CLOSE_THREAD_SHUTDOWN_ERROR_MSG)
+
 
 class BackgroundThreadTransport(Transport):
     """Asynchronous transport that uses a background thread."""
@@ -285,6 +314,7 @@ def __init__(
         """
         self.client = client
         logger = self.client.logger(name, resource=resource)
+        self.grace_period = grace_period
         self.worker = _Worker(
             logger,
             grace_period=grace_period,
@@ -307,3 +337,7 @@ def send(self, record, message, **kwargs):
     def flush(self):
         """Submit any pending log records."""
         self.worker.flush()
+
+    def close(self):
+        """Closes the worker thread."""
+        self.worker.close()
diff --git a/google/cloud/logging_v2/handlers/transports/base.py b/google/cloud/logging_v2/handlers/transports/base.py
index a0c9aafa4..31e8f418a 100644
--- a/google/cloud/logging_v2/handlers/transports/base.py
+++ b/google/cloud/logging_v2/handlers/transports/base.py
@@ -51,3 +51,11 @@ def flush(self):
 
         For blocking/sync transports, this is a no-op.
         """
+        pass
+
+    def close(self):
+        """Closes the transport and cleans up resources used by it.
+
+        This call should be followed up by disowning the transport.
+ """ + pass diff --git a/google/cloud/logging_v2/handlers/transports/sync.py b/google/cloud/logging_v2/handlers/transports/sync.py index 17a4e554e..6bf91f8da 100644 --- a/google/cloud/logging_v2/handlers/transports/sync.py +++ b/google/cloud/logging_v2/handlers/transports/sync.py @@ -59,3 +59,10 @@ def send(self, record, message, **kwargs): labels=labels, **kwargs, ) + + def close(self): + """Closes the transport and cleans up resources used by it. + + This call is usually followed up by cleaning up the reference to the transport. + """ + self.logger = None diff --git a/tests/system/test_system.py b/tests/system/test_system.py index d4ec4da36..487ecde62 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -34,6 +34,7 @@ import google.cloud.logging from google.cloud._helpers import UTC from google.cloud.logging_v2.handlers import CloudLoggingHandler +from google.cloud.logging_v2.handlers.transports import BackgroundThreadTransport from google.cloud.logging_v2.handlers.transports import SyncTransport from google.cloud.logging_v2 import client from google.cloud.logging_v2.resource import Resource @@ -719,6 +720,72 @@ def test_log_handler_otel_integration(self): self.assertEqual(entries[0].span_id, expected_span_id) self.assertTrue(entries[0].trace_sampled, expected_tracesampled) + def test_log_handler_close(self): + from multiprocessing import Process + + LOG_MESSAGE = "This is a test of handler.close before exiting." + LOGGER_NAME = "close-test" + handler_name = self._logger_name(LOGGER_NAME) + + # only create the logger to delete, hidden otherwise + logger = Config.CLIENT.logger(handler_name) + self.to_delete.append(logger) + + # Run a simulation of logging an entry then immediately shutting down. + # The .close() function before the process exits should prevent the + # thread shutdown error and let us log the message. + def subprocess_main(): + # logger.delete and logger.list_entries work by filtering on log name, so we + # can create new objects with the same name and have the queries on the parent + # process still work. + handler = CloudLoggingHandler( + Config.CLIENT, name=handler_name, transport=BackgroundThreadTransport + ) + cloud_logger = logging.getLogger(LOGGER_NAME) + cloud_logger.addHandler(handler) + cloud_logger.warning(LOG_MESSAGE) + handler.close() + + proc = Process(target=subprocess_main) + proc.start() + proc.join() + entries = _list_entries(logger) + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0].payload, LOG_MESSAGE) + + def test_log_client_flush_handlers(self): + from multiprocessing import Process + + LOG_MESSAGE = "This is a test of client.flush_handlers before exiting." + LOGGER_NAME = "close-test" + handler_name = self._logger_name(LOGGER_NAME) + + # only create the logger to delete, hidden otherwise + logger = Config.CLIENT.logger(handler_name) + self.to_delete.append(logger) + + # Run a simulation of logging an entry then immediately shutting down. + # The .close() function before the process exits should prevent the + # thread shutdown error and let us log the message. + def subprocess_main(): + # logger.delete and logger.list_entries work by filtering on log name, so we + # can create new objects with the same name and have the queries on the parent + # process still work. 
+ handler = CloudLoggingHandler( + Config.CLIENT, name=handler_name, transport=BackgroundThreadTransport + ) + cloud_logger = logging.getLogger(LOGGER_NAME) + cloud_logger.addHandler(handler) + cloud_logger.warning(LOG_MESSAGE) + Config.CLIENT.flush_handlers() + + proc = Process(target=subprocess_main) + proc.start() + proc.join() + entries = _list_entries(logger) + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0].payload, LOG_MESSAGE) + def test_create_metric(self): METRIC_NAME = "test-create-metric%s" % (_RESOURCE_ID,) metric = Config.CLIENT.metric( diff --git a/tests/unit/handlers/test_handlers.py b/tests/unit/handlers/test_handlers.py index 14b2e5cba..2e9484937 100644 --- a/tests/unit/handlers/test_handlers.py +++ b/tests/unit/handlers/test_handlers.py @@ -461,6 +461,7 @@ def test_ctor_defaults(self): self.assertEqual(handler.name, DEFAULT_LOGGER_NAME) self.assertIs(handler.client, client) self.assertIsInstance(handler.transport, _Transport) + self.assertTrue(handler._transport_open) self.assertIs(handler.transport.client, client) self.assertEqual(handler.transport.name, DEFAULT_LOGGER_NAME) global_resource = _create_global_resource(self.PROJECT) @@ -468,6 +469,17 @@ def test_ctor_defaults(self): self.assertIsNone(handler.labels) self.assertIs(handler.stream, sys.stderr) + def test_add_handler_to_client_handlers(self): + from google.cloud.logging_v2.logger import _GLOBAL_RESOURCE + + client = _Client(self.PROJECT) + handler = self._make_one( + client, + transport=_Transport, + resource=_GLOBAL_RESOURCE, + ) + self.assertIn(handler, client._handlers) + def test_ctor_explicit(self): import io from google.cloud.logging import Resource @@ -790,6 +802,56 @@ def test_emit_with_encoded_json(self): ), ) + def test_emit_after_close(self): + from google.cloud.logging_v2.logger import _GLOBAL_RESOURCE + + client = _Client(self.PROJECT) + handler = self._make_one( + client, transport=_Transport, resource=_GLOBAL_RESOURCE + ) + logname = "loggername" + message = "hello world" + record = logging.LogRecord( + logname, logging.INFO, None, None, message, None, None + ) + handler.handle(record) + old_transport = handler.transport + self.assertEqual( + handler.transport.send_called_with, + ( + record, + message, + _GLOBAL_RESOURCE, + {"python_logger": logname}, + None, + None, + False, + None, + None, + ), + ) + + handler.close() + self.assertFalse(handler._transport_open) + + handler.handle(record) + self.assertTrue(handler._transport_open) + self.assertNotEqual(handler.transport, old_transport) + self.assertEqual( + handler.transport.send_called_with, + ( + record, + message, + _GLOBAL_RESOURCE, + {"python_logger": logname}, + None, + None, + False, + None, + None, + ), + ) + def test_format_with_arguments(self): """ Handler should support format string arguments @@ -825,6 +887,20 @@ def test_format_with_arguments(self): ), ) + def test_close(self): + from google.cloud.logging_v2.logger import _GLOBAL_RESOURCE + + client = _Client(self.PROJECT) + handler = self._make_one( + client, + transport=_Transport, + resource=_GLOBAL_RESOURCE, + ) + old_transport = handler.transport + handler.close() + self.assertFalse(handler._transport_open) + self.assertTrue(old_transport.close_called) + class TestFormatAndParseMessage(unittest.TestCase): def test_none(self): @@ -1127,12 +1203,14 @@ def release(self): class _Client(object): def __init__(self, project): self.project = project + self._handlers = set() class _Transport(object): def __init__(self, client, name, resource=None): self.client = client 
self.name = name + self.close_called = False def send( self, @@ -1157,3 +1235,6 @@ def send( http_request, source_location, ) + + def close(self): + self.close_called = True diff --git a/tests/unit/handlers/transports/test_background_thread.py b/tests/unit/handlers/transports/test_background_thread.py index d4954ff7b..9fdccb172 100644 --- a/tests/unit/handlers/transports/test_background_thread.py +++ b/tests/unit/handlers/transports/test_background_thread.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import time import logging import queue +import re import unittest import mock +from io import StringIO + class TestBackgroundThreadHandler(unittest.TestCase): PROJECT = "PROJECT" @@ -176,6 +180,11 @@ def test_worker(self): class Test_Worker(unittest.TestCase): NAME = "python_logger" + def setUp(self): + import sys + + print("In method", self._testMethodName, file=sys.stderr) + @staticmethod def _get_target_class(): from google.cloud.logging_v2.handlers.transports import background_thread @@ -187,9 +196,26 @@ def _make_one(self, *args, **kw): def _start_with_thread_patch(self, worker): with mock.patch("threading.Thread", new=_Thread) as thread_mock: - with mock.patch("atexit.register") as atexit_mock: - worker.start() - return thread_mock, atexit_mock + worker.start() + return thread_mock + + @staticmethod + @contextlib.contextmanager + def _init_atexit_mock(): + atexit_mock = _AtexitMock() + with mock.patch.multiple( + "atexit", register=atexit_mock.register, unregister=atexit_mock.unregister + ): + yield atexit_mock + + @staticmethod + @contextlib.contextmanager + def _init_main_thread_is_alive_mock(is_alive): + with mock.patch("threading.main_thread") as main_thread_func_mock: + main_thread_obj_mock = mock.Mock() + main_thread_func_mock.return_value = main_thread_obj_mock + main_thread_obj_mock.is_alive = mock.Mock(return_value=is_alive) + yield def test_constructor(self): logger = _Logger(self.NAME) @@ -216,14 +242,15 @@ def test_start(self): worker = self._make_one(_Logger(self.NAME)) - _, atexit_mock = self._start_with_thread_patch(worker) + with self._init_atexit_mock() as atexit_mock: + self._start_with_thread_patch(worker) self.assertTrue(worker.is_alive) self.assertIsNotNone(worker._thread) self.assertTrue(worker._thread.daemon) self.assertEqual(worker._thread._target, worker._thread_main) self.assertEqual(worker._thread._name, background_thread._WORKER_THREAD_NAME) - atexit_mock.assert_called_once_with(worker._main_thread_terminated) + self.assertIn(worker._handle_exit, atexit_mock.registered_funcs) # Calling start again should not start a new thread. 
current_thread = worker._thread @@ -260,29 +287,33 @@ def test_stop_no_grace(self): self.assertEqual(thread._timeout, None) - def test__main_thread_terminated(self): + def test__close(self): worker = self._make_one(_Logger(self.NAME)) self._start_with_thread_patch(worker) - worker._main_thread_terminated() + worker._close("") self.assertFalse(worker.is_alive) # Calling twice should not be an error - worker._main_thread_terminated() + worker._close("") - def test__main_thread_terminated_non_empty_queue(self): + def test__close_non_empty_queue(self): worker = self._make_one(_Logger(self.NAME)) + msg = "My Message" self._start_with_thread_patch(worker) record = mock.Mock() record.created = time.time() worker.enqueue(record, "") - worker._main_thread_terminated() + + with mock.patch("sys.stderr", new_callable=StringIO) as stderr_mock: + worker._close(msg) + self.assertIn(msg, stderr_mock.getvalue()) self.assertFalse(worker.is_alive) - def test__main_thread_terminated_did_not_join(self): + def test__close_did_not_join(self): worker = self._make_one(_Logger(self.NAME)) self._start_with_thread_patch(worker) @@ -290,7 +321,65 @@ def test__main_thread_terminated_did_not_join(self): record = mock.Mock() record.created = time.time() worker.enqueue(record, "") - worker._main_thread_terminated() + worker._close("") + + self.assertFalse(worker.is_alive) + + def test__handle_exit(self): + from google.cloud.logging_v2.handlers.transports.background_thread import ( + _CLOSE_THREAD_SHUTDOWN_ERROR_MSG, + ) + + worker = self._make_one(_Logger(self.NAME)) + + with mock.patch("sys.stderr", new_callable=StringIO) as stderr_mock: + with self._init_main_thread_is_alive_mock(False): + with self._init_atexit_mock(): + self._start_with_thread_patch(worker) + self._enqueue_record(worker, "test") + worker._handle_exit() + + self.assertRegex( + stderr_mock.getvalue(), + re.compile("^%s$" % _CLOSE_THREAD_SHUTDOWN_ERROR_MSG, re.MULTILINE), + ) + + self.assertRegex( + stderr_mock.getvalue(), + re.compile( + r"^Failed to send %d pending logs\.$" % worker._queue.qsize(), + re.MULTILINE, + ), + ) + + def test__handle_exit_no_items(self): + worker = self._make_one(_Logger(self.NAME)) + + with mock.patch("sys.stderr", new_callable=StringIO) as stderr_mock: + with self._init_main_thread_is_alive_mock(False): + with self._init_atexit_mock(): + self._start_with_thread_patch(worker) + worker._handle_exit() + + self.assertEqual(stderr_mock.getvalue(), "") + + def test_close_unregister_atexit(self): + worker = self._make_one(_Logger(self.NAME)) + + with mock.patch("sys.stderr", new_callable=StringIO) as stderr_mock: + with self._init_atexit_mock() as atexit_mock: + self._start_with_thread_patch(worker) + self.assertIn(worker._handle_exit, atexit_mock.registered_funcs) + worker.close() + self.assertNotIn(worker._handle_exit, atexit_mock.registered_funcs) + + self.assertNotRegex( + stderr_mock.getvalue(), + re.compile( + r"^Failed to send %d pending logs\.$" % worker._queue.qsize(), + re.MULTILINE, + ), + ) self.assertFalse(worker.is_alive) @@ -402,6 +491,23 @@ def test__thread_main_batches(self): self.assertFalse(worker._cloud_logger._batch.commit_called) self.assertEqual(worker._queue.qsize(), 0) + def test__thread_main_main_thread_terminated(self): + from google.cloud.logging_v2.handlers.transports import background_thread + + worker = self._make_one(_Logger(self.NAME)) + self._enqueue_record(worker, "1") + worker._queue.put_nowait(background_thread._WORKER_TERMINATOR) + + with mock.patch("threading.main_thread") as 
main_thread_func_mock: + main_thread_obj_mock = mock.Mock() + main_thread_func_mock.return_value = main_thread_obj_mock + main_thread_obj_mock.is_alive = mock.Mock(return_value=False) + self._enqueue_record(worker, "1") + self._enqueue_record(worker, "2") + worker._thread_main() + + self.assertFalse(worker._cloud_logger._batch.commit_called) + @mock.patch("time.time", autospec=True, return_value=1) def test__thread_main_max_latency(self, time): # Note: this test is a bit brittle as it assumes the operation of @@ -565,3 +671,16 @@ def __init__(self, project, _http=None, credentials=None): def logger(self, name, resource=None): # pylint: disable=unused-argument self._logger = _Logger(name, resource=resource) return self._logger + + +class _AtexitMock(object): + """_AtexitMock is a simulation of registering/unregistering functions in atexit using a dummy set.""" + + def __init__(self): + self.registered_funcs = set() + + def register(self, func): + self.registered_funcs.add(func) + + def unregister(self, func): + self.registered_funcs.remove(func) diff --git a/tests/unit/handlers/transports/test_base.py b/tests/unit/handlers/transports/test_base.py index a0013cadf..b723db87b 100644 --- a/tests/unit/handlers/transports/test_base.py +++ b/tests/unit/handlers/transports/test_base.py @@ -38,3 +38,7 @@ def test_resource_is_valid_argunent(self): def test_flush_is_abstract_and_optional(self): target = self._make_one("client", "name") target.flush() + + def test_close_is_abstract_and_optional(self): + target = self._make_one("client", "name") + target.close() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2d12a283e..6a9a7fd84 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -842,6 +842,7 @@ def test_setup_logging(self): (handler,) = args self.assertIsInstance(handler, CloudLoggingHandler) + self.assertIn(handler, client._handlers) handler.transport.worker.stop() @@ -882,6 +883,7 @@ def test_setup_logging_w_extra_kwargs(self): self.assertEqual(handler.name, name) self.assertEqual(handler.resource, resource) self.assertEqual(handler.labels, labels) + self.assertIn(handler, client._handlers) handler.transport.worker.stop() @@ -929,6 +931,168 @@ def test_setup_logging_w_extra_kwargs_structured_log(self): "log_level": 20, } self.assertEqual(kwargs, expected_kwargs) + self.assertIn(handler, client._handlers) + + def test_flush_handlers_cloud_logging_handler(self): + import io + from google.cloud.logging.handlers import CloudLoggingHandler + from google.cloud.logging import Resource + + name = "test-logger" + resource = Resource("resource_type", {"resource_label": "value"}) + labels = {"handler_label": "value"} + stream = io.BytesIO() + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) + + with mock.patch("google.cloud.logging_v2.client.setup_logging") as mocked: + client.setup_logging( + name=name, resource=resource, labels=labels, stream=stream + ) + + self.assertEqual(len(mocked.mock_calls), 1) + _, args, kwargs = mocked.mock_calls[0] + + (handler,) = args + self.assertIsInstance(handler, CloudLoggingHandler) + + handler.flush = mock.Mock() + client.flush_handlers() + handler.flush.assert_called_once_with() + + def test_flush_handlers_cloud_logging_handler_no_setup_logging(self): + from google.cloud.logging.handlers import CloudLoggingHandler + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) 
+ + handler = CloudLoggingHandler(client) + self.assertIn(handler, client._handlers) + + handler.flush = mock.Mock() + client.flush_handlers() + handler.flush.assert_called_once_with() + + def test_flush_handlers_structured_log(self): + import io + from google.cloud.logging.handlers import StructuredLogHandler + from google.cloud.logging import Resource + from google.cloud.logging_v2.client import _GKE_RESOURCE_TYPE + + name = "test-logger" + resource = Resource(_GKE_RESOURCE_TYPE, {"resource_label": "value"}) + labels = {"handler_label": "value"} + stream = io.BytesIO() + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) + + with mock.patch("google.cloud.logging_v2.client.setup_logging") as mocked: + client.setup_logging( + name=name, resource=resource, labels=labels, stream=stream + ) + + self.assertEqual(len(mocked.mock_calls), 1) + _, args, kwargs = mocked.mock_calls[0] + + (handler,) = args + self.assertIsInstance(handler, StructuredLogHandler) + + handler.flush = mock.Mock() + client.flush_handlers() + handler.flush.assert_called_once_with() + + def test_close_cloud_logging_handler(self): + import contextlib + import io + from google.cloud.logging.handlers import CloudLoggingHandler + from google.cloud.logging import Resource + + name = "test-logger" + resource = Resource("resource_type", {"resource_label": "value"}) + labels = {"handler_label": "value"} + stream = io.BytesIO() + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) + + with mock.patch("google.cloud.logging_v2.client.setup_logging") as mocked: + client.setup_logging( + name=name, resource=resource, labels=labels, stream=stream + ) + + self.assertEqual(len(mocked.mock_calls), 1) + _, args, kwargs = mocked.mock_calls[0] + + (handler,) = args + self.assertIsInstance(handler, CloudLoggingHandler) + + handler.close = mock.Mock() + with contextlib.closing(client): + pass + + handler.close.assert_called_once_with() + + def test_close_cloud_logging_handler_no_setup_logging(self): + import contextlib + from google.cloud.logging.handlers import CloudLoggingHandler + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) + + handler = CloudLoggingHandler(client) + self.assertIn(handler, client._handlers) + + handler.close = mock.Mock() + with contextlib.closing(client): + pass + + handler.close.assert_called_once_with() + + def test_close_structured_log_handler(self): + import contextlib + import io + from google.cloud.logging.handlers import StructuredLogHandler + from google.cloud.logging import Resource + from google.cloud.logging_v2.client import _GKE_RESOURCE_TYPE + + name = "test-logger" + resource = Resource(_GKE_RESOURCE_TYPE, {"resource_label": "value"}) + labels = {"handler_label": "value"} + stream = io.BytesIO() + + credentials = _make_credentials() + client = self._make_one( + project=self.PROJECT, credentials=credentials, _use_grpc=False + ) + + with mock.patch("google.cloud.logging_v2.client.setup_logging") as mocked: + client.setup_logging( + name=name, resource=resource, labels=labels, stream=stream + ) + + self.assertEqual(len(mocked.mock_calls), 1) + _, args, kwargs = mocked.mock_calls[0] + + (handler,) = args + self.assertIsInstance(handler, StructuredLogHandler) + + handler.close = mock.Mock() + with contextlib.closing(client): + pass + + 
handler.close.assert_called_once_with() class _Connection(object):
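
Taken together, these changes give callers an explicit shutdown path for handlers and their transports. Below is a minimal usage sketch of the new surface; the logger name "example-log" and the use of application-default credentials are illustrative assumptions, not part of this change:

    import logging

    import google.cloud.logging
    from google.cloud.logging_v2.handlers import CloudLoggingHandler

    client = google.cloud.logging.Client()  # project/credentials resolved from the environment
    handler = CloudLoggingHandler(client, name="example-log")  # registers itself in client._handlers
    logging.getLogger().addHandler(handler)

    logging.warning("work in progress")

    # Either flush pending records explicitly while the program keeps running...
    client.flush_handlers()

    # ...or tear everything down before exit: close() drains the background
    # worker, unregisters its atexit hook, and drops the transport. Emitting
    # on the handler afterwards transparently re-creates the transport.
    client.close()

Calling handler.close() directly, as the new system test does, is equivalent for a single handler. Relying on interpreter shutdown instead now prints _CLOSE_THREAD_SHUTDOWN_ERROR_MSG when entries are still queued, rather than trying to commit them from the atexit hook.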