Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,15 @@ def _prepare_and_apply_joins(
if joined_model_filters:
stmt = stmt.filter(*joined_model_filters)

if join.sort_columns:
for idx, column_name in enumerate(join.sort_columns):
column = getattr(model, column_name, None)
if not column:
raise ArgumentError(f"Invalid column name: {column_name}")

order = join.sort_orders[idx] if join.sort_orders else "asc"
stmt = stmt.order_by(asc(column) if order == "asc" else desc(column))

return stmt

async def create(
Expand Down
22 changes: 22 additions & 0 deletions fastcrud/crud/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class JoinConfig(BaseModel):
alias: Optional[AliasedClass] = None
filters: Optional[dict] = None
relationship_type: Optional[str] = "one-to-one"
sort_columns: Optional[Union[str, list[str]]] = None
sort_orders: Optional[Union[str, list[str]]] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -37,6 +39,26 @@ def check_valid_join_type(cls, value):
raise ValueError(f"Unsupported join type: {value}")
return value

@field_validator("sort_columns")
def check_valid_sort_columns(cls, value):
if value is not None and not isinstance(value, (str, list)):
raise ValueError("sort_columns must be a string or a list of strings")
return value

@field_validator("sort_orders")
def check_valid_sort_orders(cls, value):
if value is not None:
if isinstance(value, str):
if value not in ["asc", "desc"]:
raise ValueError("Invalid sort order: {value}. Only 'asc' or 'desc' are allowed.")
elif isinstance(value, list):
for order in value:
if order not in ["asc", "desc"]:
raise ValueError("Invalid sort order: {order}. Only 'asc' or 'desc' are allowed.")
else:
raise ValueError("sort_orders must be a string or a list of strings")
return value


def _extract_matching_columns_from_schema(
model: Union[ModelType, AliasedClass],
Expand Down
182 changes: 182 additions & 0 deletions tests/sqlalchemy/crud/test_get_multi_joined.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,3 +1359,185 @@ async def test_get_multi_joined_explicit_join_preserves_condition(async_session)
assert (
task3["department"]["name"] == "Engineering"
), "Task 3 should be in Engineering department"


@pytest.mark.asyncio
async def test_get_multi_joined_sorting_nested_items_one_to_many(async_session):
cards = [
Card(title="Card A"),
Card(title="Card B"),
Card(title="Card C"),
]
async_session.add_all(cards)
await async_session.flush()

articles = [
Article(title="Article 3", card_id=cards[0].id),
Article(title="Article 1", card_id=cards[0].id),
Article(title="Article 2", card_id=cards[0].id),
Article(title="Article 2", card_id=cards[1].id),
Article(title="Article 1", card_id=cards[1].id),
]
async_session.add_all(articles)
await async_session.commit()

card_crud = FastCRUD(Card)

result = await card_crud.get_multi_joined(
db=async_session,
nest_joins=True,
joins_config=[
JoinConfig(
model=Article,
join_on=Article.card_id == Card.id,
join_prefix="articles_",
join_type="left",
relationship_type="one-to-many",
sort_columns=["title"],
sort_orders=["asc"],
)
],
)

assert result is not None, "No data returned from the database."
assert "data" in result, "Result should contain 'data' key."
data = result["data"]
assert isinstance(data, list), "Result data should be a list."
assert len(data) == 3, "Expected three card records."

card_a = next((c for c in data if c["id"] == cards[0].id), None)
card_b = next((c for c in data if c["id"] == cards[1].id), None)
card_c = next((c for c in data if c["id"] == cards[2].id), None)

assert (
card_a is not None and "articles" in card_a
), "Card A should have nested articles."
assert len(card_a["articles"]) == 3, "Card A should have three articles."
assert (
card_a["articles"][0]["title"] == "Article 1"
), "Card A's first article title should be 'Article 1'."
assert (
card_a["articles"][1]["title"] == "Article 2"
), "Card A's second article title should be 'Article 2'."
assert (
card_a["articles"][2]["title"] == "Article 3"
), "Card A's third article title should be 'Article 3'."

assert (
card_b is not None and "articles" in card_b
), "Card B should have nested articles."
assert len(card_b["articles"]) == 2, "Card B should have two articles."
assert (
card_b["articles"][0]["title"] == "Article 1"
), "Card B's first article title should be 'Article 1'."
assert (
card_b["articles"][1]["title"] == "Article 2"
), "Card B's second article title should be 'Article 2'."

assert (
card_c is not None and "articles" in card_c
), "Card C should have nested articles."
assert len(card_c["articles"]) == 0, "Card C should have no articles."


@pytest.mark.asyncio
async def test_get_multi_joined_sorting_nested_items_many_to_many(async_session):
project1 = Project(id=1, name="Project 1", description="First Project")
project2 = Project(id=2, name="Project 2", description="Second Project")

participant1 = Participant(id=1, name="Participant 3", role="Developer")
participant2 = Participant(id=2, name="Participant 1", role="Designer")
participant3 = Participant(id=3, name="Participant 2", role="Manager")

async_session.add_all([project1, project2, participant1, participant2, participant3])
await async_session.commit()

projects_participants1 = ProjectsParticipantsAssociation(
project_id=1, participant_id=1
)
projects_participants2 = ProjectsParticipantsAssociation(
project_id=1, participant_id=2
)
projects_participants3 = ProjectsParticipantsAssociation(
project_id=1, participant_id=3
)
projects_participants4 = ProjectsParticipantsAssociation(
project_id=2, participant_id=1
)
projects_participants5 = ProjectsParticipantsAssociation(
project_id=2, participant_id=2
)

async_session.add_all(
[
projects_participants1,
projects_participants2,
projects_participants3,
projects_participants4,
projects_participants5,
]
)
await async_session.commit()

crud_project = FastCRUD(Project)

join_condition_1 = Project.id == ProjectsParticipantsAssociation.project_id
join_condition_2 = ProjectsParticipantsAssociation.participant_id == Participant.id

joins_config = [
JoinConfig(
model=ProjectsParticipantsAssociation,
join_on=join_condition_1,
join_type="inner",
join_prefix="pp_",
),
JoinConfig(
model=Participant,
join_on=join_condition_2,
join_type="inner",
join_prefix="participant_",
relationship_type="one-to-many",
sort_columns=["name"],
sort_orders=["asc"],
),
]

records = await crud_project.get_multi_joined(
db=async_session,
nest_joins=True,
joins_config=joins_config,
)

assert records is not None, "No data returned from the database."
assert "data" in records, "Result should contain 'data' key."
data = records["data"]
assert isinstance(data, list), "Result data should be a list."
assert len(data) == 2, "Expected two project records."

project_1 = next((p for p in data if p["id"] == project1.id), None)
project_2 = next((p for p in data if p["id"] == project2.id), None)

assert (
project_1 is not None and "participants" in project_1
), "Project 1 should have nested participants."
assert len(project_1["participants"]) == 3, "Project 1 should have three participants."
assert (
project_1["participants"][0]["name"] == "Participant 1"
), "Project 1's first participant name should be 'Participant 1'."
assert (
project_1["participants"][1]["name"] == "Participant 2"
), "Project 1's second participant name should be 'Participant 2'."
assert (
project_1["participants"][2]["name"] == "Participant 3"
), "Project 1's third participant name should be 'Participant 3'."

assert (
project_2 is not None and "participants" in project_2
), "Project 2 should have nested participants."
assert len(project_2["participants"]) == 2, "Project 2 should have two participants."
assert (
project_2["participants"][0]["name"] == "Participant 1"
), "Project 2's first participant name should be 'Participant 1'."
assert (
project_2["participants"][1]["name"] == "Participant 2"
), "Project 2's second participant name should be 'Participant 2'."
Loading
Loading