Skip to content

Commit ce8480c

Browse files
committed
change logging structure
1 parent e47c02d commit ce8480c

File tree

3 files changed

+71
-24
lines changed

3 files changed

+71
-24
lines changed

src/torchrunx/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def entrypoint(serialized_worker_args: bytes):
8080
logger.setLevel(logging.DEBUG)
8181
logger.addHandler(socketHandler)
8282
logger.debug("creating TCPStore for worker group.")
83+
8384
store = dist.TCPStore( # pyright: ignore[reportPrivateImportUsage]
8485
host_name=worker_args.master_hostname,
8586
port=worker_args.master_port,

src/torchrunx/launcher.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
LauncherAgentGroup,
2828
LauncherPayload,
2929
LogRecordSocketReceiver,
30+
default_logging,
3031
get_open_port,
3132
)
3233

@@ -71,7 +72,7 @@ class Launcher:
7172
ssh_config_file: str | os.PathLike | None = None
7273
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
7374
log_dir: os.PathLike | str = "./logs"
74-
propagate_logs: bool = True
75+
log_spec: dict[str, list[logging.Handler]] | None = None
7576
env_vars: list[str] = field(
7677
default_factory=lambda: [
7778
"PATH",
@@ -106,16 +107,26 @@ def run(
106107

107108
logger = logging.getLogger("torchrunx")
108109
logger.setLevel(logging.DEBUG)
109-
logger.propagate = self.propagate_logs
110+
logger.propagate = False
110111

111112
log_dir = Path(self.log_dir)
112113
log_dir.mkdir(parents=True, exist_ok=True)
113-
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
114-
115-
log_file_formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(name)s:%(message)s")
116-
log_file_handler = logging.FileHandler(f"{log_dir}/{timestamp}.log")
117-
log_file_handler.setFormatter(log_file_formatter)
118-
logger.addHandler(log_file_handler)
114+
#timestamp = datetime.datetime.now().isoformat(timespec="seconds")
115+
116+
if self.log_spec is None:
117+
# TODO: this assumes the type of workers_per_host is simply int. We should consider
118+
# again whether it's worth supporting inhomogeneous allocations (list[int])
119+
self.log_spec = default_logging(num_agents=len(self.hostnames),
120+
num_workers=self.workers_per_host, # type: ignore
121+
log_dir=os.fspath(log_dir))
122+
123+
log_formatter = logging.Formatter("%(asctime)s:%(levelname)s:%(name)s:%(message)s")
124+
125+
for lname, handlers in self.log_spec.items(): # type: ignore
126+
_logger = logging.getLogger(f"torchrunx.{lname}")
127+
for handler in handlers:
128+
handler.setFormatter(log_formatter)
129+
_logger.addHandler(handler)
119130

120131
log_process = Process(target=monitor_log, args=(), daemon=True)
121132
log_process.start()
@@ -194,7 +205,7 @@ def run(
194205

195206
worker_log_names = [
196207
[
197-
f"torchrunx.agent-{i}.worker-{local_rank}"
208+
f"torchrunx.agent-{i}-worker-{local_rank}"
198209
for local_rank in range(workers_per_host[i]) # type: ignore
199210
]
200211
for i in range(len(self.hostnames))
@@ -263,7 +274,7 @@ def launch(
263274
ssh_config_file: str | os.PathLike | None = None,
264275
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
265276
log_dir: os.PathLike | str = "./logs",
266-
propagate_logs: bool = True,
277+
log_spec: dict[str, list[logging.Handler]] | None = None,
267278
env_vars: list[str] = [
268279
"PATH",
269280
"LD_LIBRARY",
@@ -294,8 +305,8 @@ def launch(
294305
:type backend: Literal['mpi', 'gloo', 'nccl', 'ucc', None], optional
295306
:param log_dir: A directory in which logs should be written, defaults to "./logs"
296307
:type log_dir: os.PathLike | str, optional
297-
:param log_level: The logging level, defaults to logging.WARN
298-
:type log_level: logging._Level, optional
308+
:param log_spec: TODO
309+
:type log_spec: TODO
299310
:param env_vars: A list of environmental variables to be copied from the launcher environment to workers. Allows for bash pattern matching syntax, defaults to ["PATH", "LD_LIBRARY", "LIBRARY_PATH", "PYTHON*", "CUDA*", "TORCH*", "PYTORCH*", "NCCL*"]
300311
:type env_vars: list[str], optional
301312
:param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
@@ -312,7 +323,7 @@ def launch(
312323
ssh_config_file=ssh_config_file,
313324
backend=backend,
314325
log_dir=log_dir,
315-
propagate_logs=propagate_logs,
326+
log_spec=log_spec,
316327
env_vars=env_vars,
317328
env_file=env_file,
318329
timeout=timeout,

src/torchrunx/utils.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def sync_payloads(
118118

119119
def sync_agent_statuses(self, status: AgentStatus) -> list[AgentStatus]:
120120
return self._all_gather(object=status)[1:]
121-
121+
122+
122123
class LogRecordStreamHandler(socketserver.StreamRequestHandler):
123124
"""Handler for a streaming logging request.
124125
@@ -136,7 +137,7 @@ def handle(self):
136137
chunk = self.connection.recv(4)
137138
if len(chunk) < 4:
138139
break
139-
slen = struct.unpack('>L', chunk)[0]
140+
slen = struct.unpack(">L", chunk)[0]
140141
chunk = self.connection.recv(slen)
141142
while len(chunk) < slen:
142143
chunk = chunk + self.connection.recv(slen - len(chunk))
@@ -150,7 +151,7 @@ def unPickle(self, data):
150151
def handleLogRecord(self, record):
151152
# if a name is specified, we use the named logger rather than the one
152153
# implied by the record.
153-
if self.server.logname is not None: # type: ignore
154+
if self.server.logname is not None: # type: ignore
154155
name = self.server.logname # type: ignore
155156
else:
156157
name = record.name
@@ -162,28 +163,62 @@ def handleLogRecord(self, record):
162163
if logger.getEffectiveLevel() <= record.levelno:
163164
logger.handle(record)
164165

166+
165167
class LogRecordSocketReceiver(socketserver.ThreadingTCPServer):
166168
"""
167169
Simple TCP socket-based logging receiver suitable for testing.
168170
"""
169171

170-
allow_reuse_address = 1 # type: ignore
172+
allow_reuse_address = 1 # type: ignore
171173

172-
def __init__(self, host='localhost',
173-
port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
174-
handler=LogRecordStreamHandler):
174+
def __init__(
175+
self,
176+
host="localhost",
177+
port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
178+
handler=LogRecordStreamHandler,
179+
):
175180
socketserver.ThreadingTCPServer.__init__(self, (host, port), handler)
176181
self.abort = 0
177182
self.timeout = 1
178183
self.logname = None
179-
180184

181185
def serve_until_stopped(self):
182186
abort = 0
183187
while not abort:
184-
rd, wr, ex = select.select([self.socket.fileno()],
185-
[], [],
186-
self.timeout)
188+
rd, wr, ex = select.select([self.socket.fileno()], [], [], self.timeout)
187189
if rd:
188190
self.handle_request()
189191
abort = self.abort
192+
193+
194+
def default_logging(
195+
num_agents: int, num_workers: int, log_dir: str
196+
) -> dict[str, list[logging.Handler]]:
197+
"""
198+
Generates torchrunx's default
199+
200+
:param num_agents: Number of agents in work group
201+
:type num_agents: int
202+
:param num_workers: Number of workers per agent
203+
:type num_workers: int
204+
:return: A logging structure to be passed to :mod:`torchrunx.launch` as the ``log_spec`` argument
205+
:rtype: dict[str, list[logging.Handler]]
206+
"""
207+
208+
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
209+
210+
agents: dict[str, list[logging.Handler]] = {
211+
f"agent-{i}": [logging.FileHandler(f"{log_dir}/{timestamp}-agent-{i}.log")]
212+
for i in range(num_agents)
213+
}
214+
workers: dict[str, list[logging.Handler]] = {
215+
f"agent-{i}-worker-{j}": [
216+
logging.FileHandler(f"{log_dir}/{timestamp}-agent-{i}.worker-{j}.log")
217+
]
218+
for j in range(num_workers)
219+
for i in range(num_agents)
220+
}
221+
222+
workers["agent-0-worker-0"].append(logging.StreamHandler())
223+
224+
return {**agents, **workers}

0 commit comments

Comments
 (0)