Skip to content

Commit 9c97d09

Browse files
committed
WorkerLogRecord class
1 parent e4ae220 commit 9c97d09

File tree

4 files changed, with 124 additions and 103 deletions.

4 files changed, with 124 additions and 103 deletions.

src/torchrunx/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from .launcher import Launcher, launch
2+
from .logging_utils import add_filter_to_handler, file_handler, stream_handler
23

34
__all__ = [
45
"Launcher",
56
"launch",
7+
"add_filter_to_handler",
8+
"file_handler",
9+
"stream_handler",
610
]

src/torchrunx/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
6767

6868
redirect_stdio_to_logger(logger)
6969

70-
store = dist.TCPStore( # pyright: ignore[reportPrivateImportUsage]
70+
store = dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
7171
host_name=worker_args.main_agent_hostname,
7272
port=worker_args.main_agent_port,
7373
world_size=worker_args.world_size,

src/torchrunx/logging_utils.py

Lines changed: 118 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,115 @@
22

33
import datetime
44
import logging
5-
import os # noqa: TCH003
65
import pickle
76
import struct
87
from contextlib import redirect_stderr, redirect_stdout
8+
from dataclasses import dataclass
99
from io import StringIO
1010
from logging import Handler, Logger
1111
from logging.handlers import SocketHandler
1212
from pathlib import Path
1313
from socketserver import StreamRequestHandler, ThreadingTCPServer
14+
from typing import TYPE_CHECKING
15+
16+
from typing_extensions import Self
17+
18+
if TYPE_CHECKING:
19+
import os
20+
21+
## Launcher utilities
22+
23+
24+
class LogRecordSocketReceiver(ThreadingTCPServer):
25+
def __init__(self, host: str, port: int, handlers: list[Handler]) -> None:
26+
self.host = host
27+
self.port = port
28+
29+
class _LogRecordStreamHandler(StreamRequestHandler):
30+
def handle(self) -> None:
31+
while True:
32+
chunk_size = 4
33+
chunk = self.connection.recv(chunk_size)
34+
if len(chunk) < chunk_size:
35+
break
36+
slen = struct.unpack(">L", chunk)[0]
37+
chunk = self.connection.recv(slen)
38+
while len(chunk) < slen:
39+
chunk = chunk + self.connection.recv(slen - len(chunk))
40+
obj = pickle.loads(chunk)
41+
record = logging.makeLogRecord(obj)
42+
43+
for handler in handlers:
44+
handler.handle(record)
45+
46+
super().__init__(
47+
server_address=(host, port),
48+
RequestHandlerClass=_LogRecordStreamHandler,
49+
bind_and_activate=True,
50+
)
51+
self.daemon_threads = True
52+
53+
def shutdown(self) -> None:
54+
"""override BaseServer.shutdown() with added timeout"""
55+
self._BaseServer__shutdown_request = True
56+
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]
57+
58+
59+
## Agent/worker utilities
60+
61+
62+
@dataclass
63+
class WorkerLogRecord(logging.LogRecord):
64+
hostname: str
65+
worker_rank: int | None
66+
67+
@classmethod
68+
def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int | None) -> Self:
69+
record.hostname = hostname
70+
record.worker_rank = worker_rank
71+
record.__class__ = cls
72+
return record # pyright: ignore [reportReturnType]
73+
74+
75+
def log_records_to_socket(
76+
logger: Logger,
77+
hostname: str,
78+
worker_rank: int | None,
79+
logger_hostname: str,
80+
logger_port: int,
81+
) -> None:
82+
logger.setLevel(logging.NOTSET)
83+
84+
old_factory = logging.getLogRecordFactory()
85+
86+
def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003
87+
record = old_factory(*args, **kwargs)
88+
return WorkerLogRecord.from_record(record, hostname, worker_rank)
89+
90+
logging.setLogRecordFactory(record_factory)
91+
92+
logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
93+
94+
95+
def redirect_stdio_to_logger(logger: Logger) -> None:
96+
class _LoggingStream(StringIO):
97+
def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
98+
super().__init__()
99+
self.logger = logger
100+
self.level = level
101+
102+
def flush(self) -> None:
103+
super().flush()
104+
value = self.getvalue()
105+
if value != "":
106+
self.logger.log(self.level, value)
107+
self.truncate(0)
108+
self.seek(0)
109+
110+
logging.captureWarnings(capture=True)
111+
redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
112+
redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
113+
14114

15115
## Handler utilities
16116

@@ -21,14 +121,27 @@ def add_filter_to_handler(
21121
worker_rank: int | None,
22122
log_level: int = logging.NOTSET,
23123
) -> None:
24-
def _filter(record: logging.LogRecord) -> bool:
124+
def _filter(record: WorkerLogRecord) -> bool:
25125
return (
26-
record.hostname == hostname # pyright: ignore[reportAttributeAccessIssue]
27-
and record.worker_rank == worker_rank # pyright: ignore[reportAttributeAccessIssue]
126+
record.hostname == hostname
127+
and record.worker_rank == worker_rank
28128
and record.levelno >= log_level
29129
)
30130

31-
handler.addFilter(_filter)
131+
handler.addFilter(_filter) # pyright: ignore [reportArgumentType]
132+
133+
134+
def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
135+
handler = logging.StreamHandler()
136+
add_filter_to_handler(handler, hostname, rank, log_level=log_level)
137+
handler.setFormatter(
138+
logging.Formatter(
139+
"%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
140+
if rank is not None
141+
else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
142+
),
143+
)
144+
return handler
32145

33146

34147
def file_handler(
@@ -67,19 +180,6 @@ def file_handlers(
67180
return handlers
68181

69182

70-
def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
71-
handler = logging.StreamHandler()
72-
add_filter_to_handler(handler, hostname, rank, log_level=log_level)
73-
handler.setFormatter(
74-
logging.Formatter(
75-
"%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
76-
if rank is not None
77-
else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
78-
),
79-
)
80-
return handler
81-
82-
83183
def default_handlers(
84184
hostnames: list[str],
85185
workers_per_host: list[int],
@@ -91,86 +191,3 @@ def default_handlers(
91191
stream_handler(hostname=hostnames[0], rank=0, log_level=log_level),
92192
*file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level),
93193
]
94-
95-
96-
## Agent/worker utilities
97-
98-
99-
def log_records_to_socket(
100-
logger: Logger,
101-
hostname: str,
102-
worker_rank: int | None,
103-
logger_hostname: str,
104-
logger_port: int,
105-
) -> None:
106-
logger.setLevel(logging.NOTSET)
107-
108-
old_factory = logging.getLogRecordFactory()
109-
110-
def record_factory(*args, **kwargs) -> logging.LogRecord: # noqa: ANN002, ANN003
111-
record = old_factory(*args, **kwargs)
112-
record.hostname = hostname
113-
record.worker_rank = worker_rank
114-
return record
115-
116-
logging.setLogRecordFactory(record_factory)
117-
118-
logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
119-
120-
121-
def redirect_stdio_to_logger(logger: Logger) -> None:
122-
class _LoggingStream(StringIO):
123-
def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
124-
super().__init__()
125-
self.logger = logger
126-
self.level = level
127-
128-
def flush(self) -> None:
129-
super().flush()
130-
value = self.getvalue()
131-
if value != "":
132-
self.logger.log(self.level, value)
133-
self.truncate(0)
134-
self.seek(0)
135-
136-
logging.captureWarnings(capture=True)
137-
redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
138-
redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
139-
140-
141-
## Launcher utilities
142-
143-
144-
class LogRecordSocketReceiver(ThreadingTCPServer):
145-
def __init__(self, host: str, port: int, handlers: list[Handler]) -> None:
146-
self.host = host
147-
self.port = port
148-
149-
class _LogRecordStreamHandler(StreamRequestHandler):
150-
def handle(self) -> None:
151-
while True:
152-
chunk_size = 4
153-
chunk = self.connection.recv(chunk_size)
154-
if len(chunk) < chunk_size:
155-
break
156-
slen = struct.unpack(">L", chunk)[0]
157-
chunk = self.connection.recv(slen)
158-
while len(chunk) < slen:
159-
chunk = chunk + self.connection.recv(slen - len(chunk))
160-
obj = pickle.loads(chunk)
161-
record = logging.makeLogRecord(obj)
162-
163-
for handler in handlers:
164-
handler.handle(record)
165-
166-
super().__init__(
167-
server_address=(host, port),
168-
RequestHandlerClass=_LogRecordStreamHandler,
169-
bind_and_activate=True,
170-
)
171-
self.daemon_threads = True
172-
173-
def shutdown(self) -> None:
174-
"""override BaseServer.shutdown() with added timeout"""
175-
self._BaseServer__shutdown_request = True
176-
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]

src/torchrunx/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __post_init__(self) -> None:
7777
backend="gloo",
7878
world_size=self.world_size,
7979
rank=self.rank,
80-
store=dist.TCPStore( # pyright: ignore[reportPrivateImportUsage]
80+
store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
8181
host_name=self.launcher_hostname,
8282
port=self.launcher_port,
8383
world_size=self.world_size,

0 commit comments

Comments
 (0)