Skip to content

Commit acb92dd

Browse files
fix(sql): Add fallback to source_defined_primary_key in CatalogProvider (#627)
Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: AJ Steers <[email protected]>
1 parent e24160e commit acb92dd

File tree

4 files changed

+114
-8
lines changed

4 files changed

+114
-8
lines changed

airbyte_cdk/sql/shared/catalog_providers.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,21 @@ def get_stream_properties(
119119
def get_primary_keys(
120120
self,
121121
stream_name: str,
122-
) -> list[str]:
123-
"""Return the primary keys for the given stream."""
124-
pks = self.get_configured_stream_info(stream_name).primary_key
122+
) -> list[str] | None:
123+
"""Return the primary key column names for the given stream.
124+
125+
We return `source_defined_primary_key` if set, or `primary_key` otherwise. If both are set,
126+
we assume they should not should differ, since Airbyte data integrity constraints do not
127+
permit overruling a source's pre-defined primary keys. If neither is set, we return `None`.
128+
129+
Returns:
130+
A list of column names that constitute the primary key, or None if no primary key is defined.
131+
"""
132+
configured_stream = self.get_configured_stream_info(stream_name)
133+
pks = configured_stream.stream.source_defined_primary_key or configured_stream.primary_key
134+
125135
if not pks:
126-
return []
136+
return None
127137

128138
normalized_pks: list[list[str]] = [
129139
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks

airbyte_cdk/sql/shared/sql_processor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,13 @@ def _merge_temp_table_to_final_table(
666666
"""
667667
nl = "\n"
668668
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
669-
pk_columns = {
670-
self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name)
671-
}
669+
primary_keys = self.catalog_provider.get_primary_keys(stream_name)
670+
if not primary_keys:
671+
raise exc.AirbyteInternalError(
672+
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
673+
context={"stream_name": stream_name},
674+
)
675+
pk_columns = {self._quote_identifier(c) for c in primary_keys}
672676
non_pk_columns = columns - pk_columns
673677
join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
674678
set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
@@ -725,6 +729,11 @@ def _emulated_merge_temp_table_to_final_table(
725729
final_table = self._get_table_by_name(final_table_name)
726730
temp_table = self._get_table_by_name(temp_table_name)
727731
pk_columns = self.catalog_provider.get_primary_keys(stream_name)
732+
if not pk_columns:
733+
raise exc.AirbyteInternalError(
734+
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
735+
context={"stream_name": stream_name},
736+
)
728737

729738
columns_to_update: set[str] = self._get_sql_column_definitions(
730739
stream_name=stream_name

airbyte_cdk/test/standard_tests/connector_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,16 @@ def connector(cls) -> type[IConnector] | Callable[[], IConnector] | None:
5959
try:
6060
module = importlib.import_module(expected_module_name)
6161
except ModuleNotFoundError as e:
62-
raise ImportError(f"Could not import module '{expected_module_name}'.") from e
62+
raise ImportError(
63+
f"Could not import module '{expected_module_name}'. "
64+
"Please ensure you are running from within the connector's virtual environment, "
65+
"for instance by running `poetry run airbyte-cdk connector test` from the "
66+
"connector directory. If the issue persists, check that the connector "
67+
f"module matches the expected module name '{expected_module_name}' and that the "
68+
f"connector class matches the expected class name '{expected_class_name}'. "
69+
"Alternatively, you can run `airbyte-cdk image test` to run a subset of tests "
70+
"against the connector's image."
71+
) from e
6372
finally:
6473
# Change back to the original working directory
6574
os.chdir(cwd_snapshot)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
5+
from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
6+
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider
7+
8+
9+
class TestCatalogProvider:
10+
"""Test cases for CatalogProvider.get_primary_keys() method."""
11+
12+
@pytest.mark.parametrize(
13+
"configured_primary_key,source_defined_primary_key,expected_result,test_description",
14+
[
15+
(["configured_id"], ["source_id"], ["source_id"], "prioritizes source when both set"),
16+
([], ["source_id"], ["source_id"], "uses source when configured empty"),
17+
(None, ["source_id"], ["source_id"], "uses source when configured None"),
18+
(
19+
["configured_id"],
20+
[],
21+
["configured_id"],
22+
"falls back to configured when source empty",
23+
),
24+
(
25+
["configured_id"],
26+
None,
27+
["configured_id"],
28+
"falls back to configured when source None",
29+
),
30+
([], [], None, "returns None when both empty"),
31+
(None, None, None, "returns None when both None"),
32+
([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"),
33+
],
34+
)
35+
def test_get_primary_keys_parametrized(
36+
self, configured_primary_key, source_defined_primary_key, expected_result, test_description
37+
):
38+
"""Test primary key fallback logic with various input combinations."""
39+
configured_pk_wrapped = (
40+
None
41+
if configured_primary_key is None
42+
else [[pk] for pk in configured_primary_key]
43+
if configured_primary_key
44+
else []
45+
)
46+
source_pk_wrapped = (
47+
None
48+
if source_defined_primary_key is None
49+
else [[pk] for pk in source_defined_primary_key]
50+
if source_defined_primary_key
51+
else []
52+
)
53+
54+
stream = AirbyteStream(
55+
name="test_stream",
56+
json_schema={
57+
"type": "object",
58+
"properties": {
59+
"id": {"type": "string"},
60+
"id1": {"type": "string"},
61+
"id2": {"type": "string"},
62+
},
63+
},
64+
supported_sync_modes=["full_refresh"],
65+
source_defined_primary_key=source_pk_wrapped,
66+
)
67+
configured_stream = ConfiguredAirbyteStream(
68+
stream=stream,
69+
sync_mode="full_refresh",
70+
destination_sync_mode="overwrite",
71+
primary_key=configured_pk_wrapped,
72+
)
73+
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])
74+
75+
provider = CatalogProvider(catalog)
76+
result = provider.get_primary_keys("test_stream")
77+
78+
assert result == expected_result

0 commit comments

Comments
 (0)