diff --git a/stub_uploader/build_wheel.py b/stub_uploader/build_wheel.py index ccedb67f..d623fbd7 100644 --- a/stub_uploader/build_wheel.py +++ b/stub_uploader/build_wheel.py @@ -135,21 +135,71 @@ def __init__(self, base_path: Path, package_data: dict[str, list[str]]) -> None: self.base_path = base_path self.package_data = package_data + def package_path(self, package_name: str) -> Path: + """Return the path of a given package name. + + The package name can use dotted notation to address sub-packages. + The top-level package name can optionally include the "-stubs" suffix. + """ + top_level, *sub_packages = package_name.split(".") + if top_level.endswith(SUFFIX): + top_level = top_level[: -len(SUFFIX)] + return self.base_path.joinpath(top_level, *sub_packages) + + def is_single_file_package(self, package_name: str) -> bool: + filename = package_name.split("-")[0] + ".pyi" + return (self.base_path / filename).exists() + @property def top_level_packages(self) -> list[str]: """Top level package names. - These are the packages that are not subpackages of any other package - and includes namespace packages. + These are the packages that are not sub-packages of any other package + and includes namespace packages. Their name includes the "-stubs" + suffix. """ return list(self.package_data.keys()) + @property + def top_level_non_namespace_packages(self) -> list[str]: + """Top level non-namespace package names. + + This will return all packages that are not subpackages of any other + package, other than namespace packages in dotted notation, e.g. if + "flying" is a top level namespace package, and "circus" is a + non-namespace sub-package, this will return ["flying.circus"]. + """ + packages: list[str] = [] + for top_level in self.top_level_packages: + if self.is_single_file_package(top_level): + packages.append(top_level) + else: + packages.extend(self._find_non_namespace_sub_packages(top_level)) + return packages + + def _find_non_namespace_sub_packages(self, package: str) -> list[str]: + path = self.package_path(package) + if is_namespace_package(path): + sub_packages: list[str] = [] + for entry in path.iterdir(): + if entry.is_dir(): + sub_name = package + "." + entry.name + sub_packages.extend(self._find_non_namespace_sub_packages(sub_name)) + return sub_packages + else: + return [package] + def add_file(self, package: str, filename: str, file_contents: str) -> None: """Add a file to a package.""" - entry_path = self.base_path / package + top_level = package.split(".")[0] + entry_path = self.package_path(package) entry_path.mkdir(exist_ok=True) (entry_path / filename).write_text(file_contents) - self.package_data[package].append(filename) + self.package_data[top_level].append(filename) + + +def is_namespace_package(path: Path) -> bool: + return not (path / "__init__.pyi").exists() def find_stub_files(top: str) -> list[str]: @@ -166,6 +216,8 @@ def find_stub_files(top: str) -> list[str]: name.isidentifier() ), "All file names must be valid Python modules" result.append(os.path.relpath(os.path.join(root, file), top)) + elif file == "py.typed": + result.append(os.path.relpath(os.path.join(root, file), top)) elif not file.endswith((".md", ".rst")): # Allow having README docs, as some stubs have these (e.g. click). if ( @@ -257,7 +309,7 @@ def collect_package_data(base_path: Path) -> PackageData: def add_partial_markers(pkg_data: PackageData) -> None: - for package in pkg_data.top_level_packages: + for package in pkg_data.top_level_non_namespace_packages: pkg_data.add_file(package, "py.typed", "partial\n") diff --git a/tests/test_integration.py b/tests/test_integration.py index 03c302b4..c5f2a90b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -28,6 +28,7 @@ from stub_uploader.ts_data import read_typeshed_data TYPESHED = "../typeshed" +THIRD_PARTY_PATH = Path(TYPESHED) / THIRD_PARTY_NAMESPACE def test_fetch_pypi_versions() -> None: @@ -37,9 +38,7 @@ def test_fetch_pypi_versions() -> None: assert not get_version.fetch_pypi_versions("types-nonexistent-distribution") -@pytest.mark.parametrize( - "distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE)) -) +@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH)) def test_build_wheel(distribution: str) -> None: """Check that we can build wheels for all distributions.""" tmp_dir = build_wheel.main(TYPESHED, distribution, version="1.1.1") @@ -47,9 +46,7 @@ def test_build_wheel(distribution: str) -> None: assert list(os.listdir(tmp_dir)) # check it is not empty -@pytest.mark.parametrize( - "distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE)) -) +@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH)) def test_version_increment(distribution: str) -> None: get_version.determine_stub_version(read_metadata(TYPESHED, distribution)) @@ -145,9 +142,7 @@ def test_dependency_order_single() -> None: ] -@pytest.mark.parametrize( - "distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE)) -) +@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH)) def test_recursive_verify(distribution: str) -> None: recursive_verify(read_metadata(TYPESHED, distribution), TYPESHED) @@ -170,3 +165,35 @@ def test_verify_requires_python() -> None: InvalidRequires, match="Expected requires_python to be a '>=' specifier" ): verify_requires_python("==3.10") + + +@pytest.mark.parametrize( + "distribution,expected_packages", + [ + ("pytz", ["pytz-stubs"]), + ("Pillow", ["PIL-stubs"]), + ("protobuf", ["google-stubs"]), + ("google-cloud-ndb", ["google-stubs"]), + ], +) +def test_pkg_data_top_level_packages( + distribution: str, expected_packages: list[str] +) -> None: + pkg_data = build_wheel.collect_package_data(THIRD_PARTY_PATH / distribution) + assert pkg_data.top_level_packages == expected_packages + + +@pytest.mark.parametrize( + "distribution,expected_packages", + [ + ("pytz", ["pytz-stubs"]), + ("Pillow", ["PIL-stubs"]), + ("protobuf", ["google-stubs.protobuf"]), + ("google-cloud-ndb", ["google-stubs.cloud.ndb"]), + ], +) +def test_pkg_data_non_namespace_packages( + distribution: str, expected_packages: list[str] +) -> None: + pkg_data = build_wheel.collect_package_data(THIRD_PARTY_PATH / distribution) + assert pkg_data.top_level_non_namespace_packages == expected_packages