2
2
3
3
import datetime
4
4
import logging
5
- import os # noqa: TCH003
6
5
import pickle
7
6
import struct
8
7
from contextlib import redirect_stderr , redirect_stdout
8
+ from dataclasses import dataclass
9
9
from io import StringIO
10
10
from logging import Handler , Logger
11
11
from logging .handlers import SocketHandler
12
12
from pathlib import Path
13
13
from socketserver import StreamRequestHandler , ThreadingTCPServer
14
+ from typing import TYPE_CHECKING
15
+
16
+ from typing_extensions import Self
17
+
18
+ if TYPE_CHECKING :
19
+ import os
20
+
21
+ ## Launcher utilities
22
+
23
+
24
class LogRecordSocketReceiver(ThreadingTCPServer):
    """Threaded TCP server that receives log records and feeds them to handlers.

    Each connection carries a stream of frames: a 4-byte big-endian unsigned
    length followed by that many bytes of pickled record attributes. Every
    decoded record is passed to each handler in ``handlers``.

    Args:
        host: Address to bind the server to.
        port: Port to bind the server to.
        handlers: Logging handlers that receive every decoded record.
    """

    def __init__(self, host: str, port: int, handlers: list[Handler]) -> None:
        self.host = host
        self.port = port

        class _LogRecordStreamHandler(StreamRequestHandler):
            """Per-connection handler; closes over ``handlers`` from the server."""

            def _recv_exact(self, size: int) -> bytes:
                """Read ``size`` bytes, returning fewer only if the peer closed."""
                buffer = bytearray()
                while len(buffer) < size:
                    piece = self.connection.recv(size - len(buffer))
                    if not piece:
                        # recv() returns b"" once the peer has closed; looping
                        # again would spin forever without making progress.
                        break
                    buffer += piece
                return bytes(buffer)

            def handle(self) -> None:
                header_size = 4
                while True:
                    header = self._recv_exact(header_size)
                    if len(header) < header_size:
                        break  # connection closed (cleanly or mid-header)
                    (payload_size,) = struct.unpack(">L", header)
                    payload = self._recv_exact(payload_size)
                    if len(payload) < payload_size:
                        break  # connection closed mid-record; drop the partial frame
                    # SECURITY NOTE(review): pickle.loads on data read from a
                    # socket can execute arbitrary code. Only expose this
                    # server to trusted peers on a trusted network.
                    attributes = pickle.loads(payload)
                    record = logging.makeLogRecord(attributes)

                    for handler in handlers:
                        handler.handle(record)

        super().__init__(
            server_address=(host, port),
            RequestHandlerClass=_LogRecordStreamHandler,
            bind_and_activate=True,
        )
        # Handler threads are daemonic so they never block interpreter exit.
        self.daemon_threads = True

    def shutdown(self) -> None:
        """Request shutdown like ``BaseServer.shutdown()``, waiting at most 3 seconds.

        The attributes are BaseServer's private (name-mangled) shutdown flag
        and event, accessed directly so the wait can be given a timeout.
        """
        self._BaseServer__shutdown_request = True
        self._BaseServer__is_shut_down.wait(timeout=3)  # pyright: ignore[reportAttributeAccessIssue]
57
+
58
+
59
+ ## Agent/worker utilities
60
+
61
+
62
+ @dataclass
63
+ class WorkerLogRecord (logging .LogRecord ):
64
+ hostname : str
65
+ worker_rank : int | None
66
+
67
+ @classmethod
68
+ def from_record (cls , record : logging .LogRecord , hostname : str , worker_rank : int | None ) -> Self :
69
+ record .hostname = hostname
70
+ record .worker_rank = worker_rank
71
+ record .__class__ = cls
72
+ return record # pyright: ignore [reportReturnType]
73
+
74
+
75
def log_records_to_socket(
    logger: Logger,
    hostname: str,
    worker_rank: int | None,
    logger_hostname: str,
    logger_port: int,
) -> None:
    """Tag all new log records with this worker's identity and ship them over a socket.

    Sets ``logger`` to ``NOTSET``, replaces the process-wide log record factory
    with one that wraps the previous factory and tags each record with
    ``hostname`` and ``worker_rank``, then attaches a ``SocketHandler`` to
    ``logger``.

    Args:
        logger: Logger that receives the socket handler.
        hostname: Hostname stamped on every record.
        worker_rank: Worker rank stamped on every record (may be None).
        logger_hostname: Host the ``SocketHandler`` sends to.
        logger_port: Port the ``SocketHandler`` sends to.
    """
    logger.setLevel(logging.NOTSET)

    # Keep the factory that was active so the new one can delegate to it.
    previous_factory = logging.getLogRecordFactory()

    def _tagging_factory(*args, **kwargs) -> WorkerLogRecord:  # noqa: ANN002, ANN003
        return WorkerLogRecord.from_record(
            previous_factory(*args, **kwargs),
            hostname,
            worker_rank,
        )

    logging.setLogRecordFactory(_tagging_factory)

    socket_handler = SocketHandler(host=logger_hostname, port=logger_port)
    logger.addHandler(socket_handler)
93
+
94
+
95
def redirect_stdio_to_logger(logger: Logger) -> None:
    """Route warnings, stderr and stdout into ``logger``.

    Enables ``logging.captureWarnings`` and replaces ``sys.stderr`` and
    ``sys.stdout`` with in-memory streams that log their buffered text when
    flushed: stderr at ERROR, stdout at INFO. The redirections are entered
    and never exited.

    Args:
        logger: Logger that receives the redirected output.
    """

    class _LoggingStream(StringIO):
        """StringIO whose ``flush`` emits the buffered text as one log message."""

        def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
            super().__init__()
            self.logger = logger
            self.level = level

        def flush(self) -> None:
            super().flush()
            pending = self.getvalue()
            if not pending:
                return
            self.logger.log(self.level, pending)
            # Reset the buffer so the same text is not logged again.
            self.seek(0)
            self.truncate()

    logging.captureWarnings(capture=True)

    stderr_stream = _LoggingStream(logger, level=logging.ERROR)
    redirect_stderr(stderr_stream).__enter__()

    stdout_stream = _LoggingStream(logger, level=logging.INFO)
    redirect_stdout(stdout_stream).__enter__()
113
+
14
114
15
115
## Handler utilities
16
116
@@ -21,14 +121,27 @@ def add_filter_to_handler(
21
121
worker_rank : int | None ,
22
122
log_level : int = logging .NOTSET ,
23
123
) -> None :
24
- def _filter (record : logging . LogRecord ) -> bool :
124
+ def _filter (record : WorkerLogRecord ) -> bool :
25
125
return (
26
- record .hostname == hostname # pyright: ignore[reportAttributeAccessIssue]
27
- and record .worker_rank == worker_rank # pyright: ignore[reportAttributeAccessIssue]
126
+ record .hostname == hostname
127
+ and record .worker_rank == worker_rank
28
128
and record .levelno >= log_level
29
129
)
30
130
31
- handler .addFilter (_filter )
131
+ handler .addFilter (_filter ) # pyright: ignore [reportArgumentType]
132
+
133
+
134
def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
    """Build a ``StreamHandler`` filtered to one host/worker.

    Args:
        hostname: Hostname passed to ``add_filter_to_handler``.
        rank: Worker rank passed to ``add_filter_to_handler``; when not None
            the rank is also included in the output format.
        log_level: Level passed to ``add_filter_to_handler``.

    Returns:
        The configured handler.
    """
    if rank is None:
        line_format = "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s"
    else:
        line_format = "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"

    console = logging.StreamHandler()
    add_filter_to_handler(console, hostname, rank, log_level=log_level)
    console.setFormatter(logging.Formatter(line_format))
    return console
32
145
33
146
34
147
def file_handler (
@@ -67,19 +180,6 @@ def file_handlers(
67
180
return handlers
68
181
69
182
70
- def stream_handler (hostname : str , rank : int | None , log_level : int = logging .NOTSET ) -> Handler :
71
- handler = logging .StreamHandler ()
72
- add_filter_to_handler (handler , hostname , rank , log_level = log_level )
73
- handler .setFormatter (
74
- logging .Formatter (
75
- "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
76
- if rank is not None
77
- else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s" ,
78
- ),
79
- )
80
- return handler
81
-
82
-
83
183
def default_handlers (
84
184
hostnames : list [str ],
85
185
workers_per_host : list [int ],
@@ -91,86 +191,3 @@ def default_handlers(
91
191
stream_handler (hostname = hostnames [0 ], rank = 0 , log_level = log_level ),
92
192
* file_handlers (hostnames , workers_per_host , log_dir = log_dir , log_level = log_level ),
93
193
]
94
-
95
-
96
- ## Agent/worker utilities
97
-
98
-
99
- def log_records_to_socket (
100
- logger : Logger ,
101
- hostname : str ,
102
- worker_rank : int | None ,
103
- logger_hostname : str ,
104
- logger_port : int ,
105
- ) -> None :
106
- logger .setLevel (logging .NOTSET )
107
-
108
- old_factory = logging .getLogRecordFactory ()
109
-
110
- def record_factory (* args , ** kwargs ) -> logging .LogRecord : # noqa: ANN002, ANN003
111
- record = old_factory (* args , ** kwargs )
112
- record .hostname = hostname
113
- record .worker_rank = worker_rank
114
- return record
115
-
116
- logging .setLogRecordFactory (record_factory )
117
-
118
- logger .addHandler (SocketHandler (host = logger_hostname , port = logger_port ))
119
-
120
-
121
- def redirect_stdio_to_logger (logger : Logger ) -> None :
122
- class _LoggingStream (StringIO ):
123
- def __init__ (self , logger : Logger , level : int = logging .NOTSET ) -> None :
124
- super ().__init__ ()
125
- self .logger = logger
126
- self .level = level
127
-
128
- def flush (self ) -> None :
129
- super ().flush ()
130
- value = self .getvalue ()
131
- if value != "" :
132
- self .logger .log (self .level , value )
133
- self .truncate (0 )
134
- self .seek (0 )
135
-
136
- logging .captureWarnings (capture = True )
137
- redirect_stderr (_LoggingStream (logger , level = logging .ERROR )).__enter__ ()
138
- redirect_stdout (_LoggingStream (logger , level = logging .INFO )).__enter__ ()
139
-
140
-
141
- ## Launcher utilities
142
-
143
-
144
- class LogRecordSocketReceiver (ThreadingTCPServer ):
145
- def __init__ (self , host : str , port : int , handlers : list [Handler ]) -> None :
146
- self .host = host
147
- self .port = port
148
-
149
- class _LogRecordStreamHandler (StreamRequestHandler ):
150
- def handle (self ) -> None :
151
- while True :
152
- chunk_size = 4
153
- chunk = self .connection .recv (chunk_size )
154
- if len (chunk ) < chunk_size :
155
- break
156
- slen = struct .unpack (">L" , chunk )[0 ]
157
- chunk = self .connection .recv (slen )
158
- while len (chunk ) < slen :
159
- chunk = chunk + self .connection .recv (slen - len (chunk ))
160
- obj = pickle .loads (chunk )
161
- record = logging .makeLogRecord (obj )
162
-
163
- for handler in handlers :
164
- handler .handle (record )
165
-
166
- super ().__init__ (
167
- server_address = (host , port ),
168
- RequestHandlerClass = _LogRecordStreamHandler ,
169
- bind_and_activate = True ,
170
- )
171
- self .daemon_threads = True
172
-
173
- def shutdown (self ) -> None :
174
- """override BaseServer.shutdown() with added timeout"""
175
- self ._BaseServer__shutdown_request = True
176
- self ._BaseServer__is_shut_down .wait (timeout = 3 ) # pyright: ignore[reportAttributeAccessIssue]
0 commit comments