Commit 060850b

Merge pull request #61 from apoorvkh/misc-refactoring
Misc refactoring
2 parents 81f5e91 + 9c97d09

13 files changed (+399, -304 lines)

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
@@ -86,4 +86,4 @@ jobs:
           cache: false
           environments: default
           activate-environment: default
-      - run: pytest tests/test_CI.py
+      - run: pytest tests/test_ci.py

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 19 additions & 2 deletions
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "torchrunx"
-version = "0.1.3"
+version = "0.2.0"
 authors = [
   {name = "Apoorv Khandelwal", email = "[email protected]"},
   {name = "Peter Curtin", email = "[email protected]"},
@@ -41,7 +41,24 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
 line-length = 100
 src = ["src", "tests"]
 [tool.ruff.lint]
-select = ["E", "F", "B", "UP", "I"]
+select = ["ALL"]
+ignore = [
+    "D", # documentation
+    "ANN101", "ANN102", "ANN401", # self / cls / Any annotations
+    "BLE001", # blind exceptions
+    "TD", # todo syntax
+    "FIX002", # existing todos
+    "PLR0913", # too many arguments
+    "DTZ005", # datetime timezone
+    "S301", # bandit: pickle
+    "S603", "S607", # bandit: subprocess
+    "COM812", "ISC001", # conflict with formatter
+]
+[tool.ruff.lint.per-file-ignores]
+"tests/**/*.py" = [
+    "S101", # allow asserts
+    "T201" # allow prints
+]

 [tool.pyright]
 include = ["src", "tests"]
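For illustration only (not part of this commit): with select = ["ALL"], Ruff's S101 (assert) and T201 (print) rules would normally flag test code, so the per-file-ignores entry above keeps them allowed under tests/. A hypothetical test file:

# tests/test_example.py (hypothetical file, not from this repository)
def test_addition() -> None:
    total = 1 + 1
    print(f"total = {total}")  # T201 (print) is ignored under tests/**/*.py
    assert total == 2  # S101 (assert) is ignored under tests/**/*.py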

src/torchrunx/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 from .launcher import Launcher, launch
+from .logging_utils import add_filter_to_handler, file_handler, stream_handler

 __all__ = [
     "Launcher",
     "launch",
+    "add_filter_to_handler",
+    "file_handler",
+    "stream_handler",
 ]
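Usage sketch (not part of the diff): the logging helpers are now re-exported from the package root, so they can be imported alongside Launcher and launch. Their signatures are not shown in this commit, so only the imports are illustrated.

# Hypothetical import sketch; only the exported names are confirmed by this diff.
from torchrunx import add_filter_to_handler, file_handler, stream_handler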

src/torchrunx/agent.py

Lines changed: 28 additions & 38 deletions
@@ -6,14 +6,14 @@
 import socket
 import sys
 import tempfile
+import traceback
 from dataclasses import dataclass
 from typing import Any, Callable, Literal

 import cloudpickle
 import torch
 import torch.distributed as dist
-from torch.distributed.elastic.multiprocessing import start_processes
-from typing_extensions import Self
+import torch.distributed.elastic.multiprocessing as dist_mp

 from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
 from .utils import (
@@ -40,16 +40,20 @@ class WorkerArgs:
     hostname: str
     timeout: int

-    def to_bytes(self) -> bytes:
-        return cloudpickle.dumps(self)
+    def serialize(self) -> SerializedWorkerArgs:
+        return SerializedWorkerArgs(worker_args=self)

-    @classmethod
-    def from_bytes(cls, serialized: bytes) -> Self:
-        return cloudpickle.loads(serialized)

+class SerializedWorkerArgs:
+    def __init__(self, worker_args: WorkerArgs) -> None:
+        self.bytes = cloudpickle.dumps(worker_args)

-def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
-    worker_args = WorkerArgs.from_bytes(serialized_worker_args)
+    def deserialize(self) -> WorkerArgs:
+        return cloudpickle.loads(self.bytes)
+
+
+def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
+    worker_args: WorkerArgs = serialized_worker_args.deserialize()

     logger = logging.getLogger()

@@ -63,18 +67,14 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:

     redirect_stdio_to_logger(logger)

-    store = dist.TCPStore(  # pyright: ignore[reportPrivateImportUsage]
+    store = dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
         host_name=worker_args.main_agent_hostname,
         port=worker_args.main_agent_port,
         world_size=worker_args.world_size,
         is_master=(worker_args.rank == 0),
     )

-    backend = worker_args.backend
-    if backend is None:
-        backend = "nccl" if torch.cuda.is_available() else "gloo"
-
-    logger.debug(f"using backend: {backend}")
+    backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")

     dist.init_process_group(
         backend=backend,
@@ -91,19 +91,17 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
     os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
     os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

-    logger.debug(f"executing function: {worker_args.function}")
-
     try:
         return worker_args.function()
     except Exception as e:
-        logger.error(e)
+        traceback.print_exc()
         return WorkerException(exception=e)
     finally:
         sys.stdout.flush()
         sys.stderr.flush()


-def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
+def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
     agent_rank = launcher_agent_group.rank - 1

     payload = AgentPayload(
@@ -132,16 +130,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_

     redirect_stdio_to_logger(logger)

-    if torch.__version__ >= "2.3":
-        from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
-
-        log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
-    else:
-        log_kwargs = {"log_dir": tempfile.mkdtemp()}
-
     # spawn workers

-    ctx = start_processes(
+    ctx = dist_mp.start_processes(
         name=f"{hostname}_",
         entrypoint=entrypoint,
         args={
@@ -159,31 +150,30 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
                     world_size=worker_world_size,
                     hostname=launcher_payload.hostnames[agent_rank],
                     timeout=launcher_payload.timeout,
-                ).to_bytes(),
+                ).serialize(),
             )
             for i in range(num_workers)
         },
         envs={i: {} for i in range(num_workers)},
-        **log_kwargs,  # pyright: ignore [reportArgumentType]
+        **(
+            {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
+            if torch.__version__ >= "2.3"
+            else {"log_dir": tempfile.mkdtemp()}
+        ),  # pyright: ignore [reportArgumentType]
     )
-    logger.info("starting processes")

     try:
         status = None
         while True:
             if status is None or status.state == "running":
-                status = AgentStatus.from_result(
-                    result=ctx.wait(5), worker_global_ranks=worker_global_ranks
-                )
+                status = AgentStatus.from_result(ctx.wait(5))

             agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)

-            if all(s.state == "done" for s in agent_statuses):
-                break
-            elif any(s.state == "failed" for s in agent_statuses):
+            all_done = all(s.state == "done" for s in agent_statuses)
+            any_failed = any(s.state == "failed" for s in agent_statuses)
+            if all_done or any_failed:
                 break
-    except:
-        raise
     finally:
         ctx.close()
         sys.stdout.flush()
src/torchrunx/environment.py

Lines changed: 11 additions & 6 deletions
@@ -17,7 +17,9 @@ def slurm_hosts() -> list[str]:
     :rtype: list[str]
     """
     # TODO: sanity check SLURM variables, commands
-    assert in_slurm_job()
+    if not in_slurm_job():
+        msg = "Not in a SLURM job"
+        raise RuntimeError(msg)
     return (
         subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
         .decode()
@@ -35,15 +37,18 @@ def slurm_workers() -> int:
     :rtype: int
     """
     # TODO: sanity check SLURM variables, commands
-    assert in_slurm_job()
+    if not in_slurm_job():
+        msg = "Not in a SLURM job"
+        raise RuntimeError(msg)
+
     if "SLURM_JOB_GPUS" in os.environ:
         # TODO: is it possible to allocate uneven GPUs across nodes?
         return len(os.environ["SLURM_JOB_GPUS"].split(","))
-    elif "SLURM_GPUS_PER_NODE" in os.environ:
+    if "SLURM_GPUS_PER_NODE" in os.environ:
         return int(os.environ["SLURM_GPUS_PER_NODE"])
-    else:
-        # TODO: should we assume that we plan to do one worker per CPU?
-        return int(os.environ["SLURM_CPUS_ON_NODE"])
+
+    # TODO: should we assume that we plan to do one worker per CPU?
+    return int(os.environ["SLURM_CPUS_ON_NODE"])


 def auto_hosts() -> list[str]:
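Usage sketch (not part of the diff), assuming torchrunx is installed and these helpers remain in torchrunx.environment as shown: slurm_hosts() and slurm_workers() now raise RuntimeError outside a SLURM job instead of failing an assert.

# Hypothetical usage; only the function names and the RuntimeError behavior
# are confirmed by the diff above.
from torchrunx.environment import slurm_hosts, slurm_workers

try:
    hosts = slurm_hosts()      # hostnames parsed from SLURM_JOB_NODELIST
    workers = slurm_workers()  # GPUs per node, else CPUs per node
    print(f"{len(hosts)} node(s), {workers} worker(s) per node")
except RuntimeError as err:
    print(f"Not running under SLURM: {err}")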
