Skip to content

Commit e28b74a

Browse files
committed
Automatically generate datapipe.pyi in setup.py
ghstack-source-id: 56c0295 Pull Request resolved: #290
1 parent 8992d09 commit e28b74a

File tree

5 files changed

+21
-373
lines changed

5 files changed

+21
-373
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dist/*
33
torchdata.egg-info/*
44

55
torchdata/version.py
6+
torchdata/datapipes/iter/__init__.pyi
67

78
# Editor temporaries
89
*.swn

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from pathlib import Path
1010

1111
from setuptools import find_packages, setup
12+
from torchdata.datapipes.gen_pyi import gen_pyi
13+
1214

1315
ROOT_DIR = Path(__file__).parent.resolve()
1416

@@ -110,3 +112,4 @@ def get_parser():
110112
packages=find_packages(exclude=["test*", "examples*"]),
111113
zip_safe=False,
112114
)
115+
gen_pyi()

torchdata/datapipes/gen_pyi.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import os
23
import pathlib
4+
from pathlib import Path
35
from typing import Dict, List, Optional, Set
46

57
import torch.utils.data.gen_pyi as core_gen_pyi
6-
from torch.utils.data.gen_pyi import FileManager, get_method_definitions
8+
from torch.utils.data.gen_pyi import gen_from_template, get_method_definitions
79

810

911
def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None):
@@ -18,14 +20,17 @@ def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None)
1820
if skip_line in line:
1921
skip_flag = True
2022
if not skip_flag:
23+
line = line.replace("\n", "")
2124
res.append(line)
2225
return res
2326

2427

25-
def main() -> None:
28+
def gen_pyi() -> None:
29+
ROOT_DIR = Path(__file__).parent.resolve()
30+
print(f"Generating DataPipe Python interface file in {ROOT_DIR}")
2631

2732
iter_init_base = get_lines_base_file(
28-
"iter/__init__.py",
33+
os.path.join(ROOT_DIR, "iter/__init__.py"),
2934
{"from torch.utils.data import IterDataPipe", "# Copyright (c) Facebook, Inc. and its affiliates."},
3035
)
3136

@@ -69,14 +74,16 @@ def main() -> None:
6974

7075
iter_method_definitions = core_iter_method_definitions + td_iter_method_definitions
7176

72-
fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
73-
fm.write_with_template(
74-
filename="iter/__init__.pyi",
75-
template_fn="iter/__init__.pyi.in",
76-
env_callable=lambda: {"init_base": iter_init_base, "IterDataPipeMethods": iter_method_definitions},
77+
replacements = [("${init_base}", iter_init_base, 0), ("${IterDataPipeMethods}", iter_method_definitions, 4)]
78+
79+
gen_from_template(
80+
dir=str(ROOT_DIR),
81+
template_name="iter/__init__.pyi.in",
82+
output_name="iter/__init__.pyi",
83+
replacements=replacements,
7784
)
7885
# TODO: Add map_method_definitions when there are MapDataPipes defined in this library
7986

8087

8188
if __name__ == "__main__":
82-
main() # TODO: Run this script automatically within the build and CI process
89+
gen_pyi()

0 commit comments

Comments
 (0)