1
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
2
3
import pathlib
4
+ from pathlib import Path
3
5
from typing import Dict , List , Optional , Set
4
6
5
7
import torch .utils .data .gen_pyi as core_gen_pyi
6
- from torch .utils .data .gen_pyi import FileManager , get_method_definitions
8
+ from torch .utils .data .gen_pyi import gen_from_template , get_method_definitions
7
9
8
10
9
11
def get_lines_base_file (base_file_path : str , to_skip : Optional [Set [str ]] = None ):
@@ -18,14 +20,17 @@ def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None)
18
20
if skip_line in line :
19
21
skip_flag = True
20
22
if not skip_flag :
23
+ line = line .replace ("\n " , "" )
21
24
res .append (line )
22
25
return res
23
26
24
27
25
- def main () -> None :
28
+ def gen_pyi () -> None :
29
+ ROOT_DIR = Path (__file__ ).parent .resolve ()
30
+ print (f"Generating DataPipe Python interface file in { ROOT_DIR } " )
26
31
27
32
iter_init_base = get_lines_base_file (
28
- "iter/__init__.py" ,
33
+ os . path . join ( ROOT_DIR , "iter/__init__.py" ) ,
29
34
{"from torch.utils.data import IterDataPipe" , "# Copyright (c) Facebook, Inc. and its affiliates." },
30
35
)
31
36
@@ -69,14 +74,16 @@ def main() -> None:
69
74
70
75
iter_method_definitions = core_iter_method_definitions + td_iter_method_definitions
71
76
72
- fm = FileManager (install_dir = "." , template_dir = "." , dry_run = False )
73
- fm .write_with_template (
74
- filename = "iter/__init__.pyi" ,
75
- template_fn = "iter/__init__.pyi.in" ,
76
- env_callable = lambda : {"init_base" : iter_init_base , "IterDataPipeMethods" : iter_method_definitions },
77
+ replacements = [("${init_base}" , iter_init_base , 0 ), ("${IterDataPipeMethods}" , iter_method_definitions , 4 )]
78
+
79
+ gen_from_template (
80
+ dir = str (ROOT_DIR ),
81
+ template_name = "iter/__init__.pyi.in" ,
82
+ output_name = "iter/__init__.pyi" ,
83
+ replacements = replacements ,
77
84
)
78
85
# TODO: Add map_method_definitions when there are MapDataPipes defined in this library
79
86
80
87
81
88
if __name__ == "__main__" :
82
- main () # TODO: Run this script automatically within the build and CI process
89
+ gen_pyi ()
0 commit comments