Atlas search lookups #325

Open · wants to merge 27 commits into base: main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -81,4 +81,4 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
args: ["-L", "nin"]
args: ["-L", "nin", "-L", "searchin"]
2 changes: 1 addition & 1 deletion django_mongodb_backend/__init__.py
@@ -8,7 +8,7 @@

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .expressions.builtins import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
156 changes: 133 additions & 23 deletions django_mongodb_backend/compiler.py
@@ -16,6 +16,7 @@
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .expressions.search import SearchExpression, SearchVector
from .query import MongoQuery, wrap_database_errors


@@ -33,6 +34,8 @@ def __init__(self, *args, **kwargs):
# A list of OrderBy objects for this query.
self.order_by_objs = None
self.subqueries = []
# Atlas search calls
self.search_pipeline = []

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
@@ -56,6 +59,29 @@ def _get_column_from_expression(self, expr, alias):
column_target.set_attributes_from_name(alias)
return Col(self.collection_name, column_target)

def _get_replace_expr(self, sub_expr, group, alias):
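"""
Add `sub_expr`'s MQL to `group` under `alias` and return the expression
that replaces `sub_expr` in later pipeline stages: a column referencing
`alias`, adjusted for distinct ($addToSet), Count (coalesced to 0), and
Variance (StdDev squared).
"""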
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if getattr(sub_expr, "distinct", False):
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
return replacing_expr

def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx):
"""
Prepare expressions for the aggregation pipeline.
@@ -79,29 +105,45 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
alias = (
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
)
column_target = sub_expr.output_field.clone()
column_target.db_column = alias
column_target.set_attributes_from_name(alias)
inner_column = Col(self.collection_name, column_target)
if sub_expr.distinct:
# If the expression should return distinct values, use
# $addToSet to deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
replacements[sub_expr] = replacing_expr
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias)
return replacements, group

def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements):
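"""
Collect the search expressions inside `expression` and, for each one not
already replaced, register a replacement column aliased
`__search_expr.search#` in `replacements`.
"""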
searches = {}
for sub_expr in self._get_search_expressions(expression):
if sub_expr not in replacements:
alias = f"__search_expr.search{next(search_idx)}"
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)

def _prepare_search_query_for_aggregation_pipeline(self, order_by):
Comment on lines +111 to +118
Contributor

I know these are private functions, but can they get a docstring? Same with _get_replace_expr. It's quite complex code, so it becomes harder to follow.

"""
Prepare expressions for the search pipeline.

Handle the computation of search functions used by various
expressions: create intermediate columns for each search and replace
the original nodes so later stages can reference the search results.

MongoDB's $search and $vectorSearch are pipeline stages. To apply
operations on top of them, compute the $search or $vectorSearch stage
first, then apply the remaining operations in a subsequent stage,
replacing the search expressions with new document fields prefixed by
`__search_expr.search#`.
"""
replacements = {}
annotation_group_idx = itertools.count(start=1)
for expr in self.query.annotation_select.values():
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

for expr, _ in order_by:
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements)

self._prepare_search_expressions_for_pipeline(
self.having, annotation_group_idx, replacements
)
self._prepare_search_expressions_for_pipeline(
self.get_where(), annotation_group_idx, replacements
)
return replacements

def _prepare_annotations_for_aggregation_pipeline(self, order_by):
"""Prepare annotations for the aggregation pipeline."""
replacements = {}
@@ -206,9 +248,66 @@ def _build_aggregation_pipeline(self, ids, group):
pipeline.append({"$unset": "_id"})
return pipeline

def _compound_searches_queries(self, search_replacements):
Collaborator Author

I want to preserve this function for the future; we'll probably want hybrid search, and this part of the code could be useful. I know it's weird to check that the replacements have length 1 and then iterate over them. Also, the exception could be raised before this point. Let me know if you want me to refactor this code.

Contributor

I'm fine with it; please just add a docstring to explain the function and an additional comment explaining the need for the checks.

"""
Build a query pipeline from a mapping of search expressions to result columns.

Currently, only a single `$search` or `$vectorSearch` expression is supported.
Combining multiple search expressions is not yet allowed and will raise a ValueError.

This method will eventually support hybrid search by allowing the combination of
`$search` and `$vectorSearch` operations.
"""
if not search_replacements:
return []
if len(search_replacements) > 1:
has_search = any(not isinstance(search, SearchVector) for search in search_replacements)
has_vector_search = any(
isinstance(search, SearchVector) for search in search_replacements
)
if has_search and has_vector_search:
raise ValueError(
"Cannot combine a `$vectorSearch` with a `$search` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
if has_vector_search:
raise ValueError(
"Cannot combine two `$vectorSearch` operator. "
"If you need to combine them, consider restructuring your query logic or "
"running them as separate queries."
)
raise ValueError(
"Only one $search operation is allowed per query. "
f"Received {len(search_replacements)} search expressions. "
"To combine multiple search expressions, use either a CompoundExpression for "
"fine-grained control or CombinedSearchExpression for simple logical combinations."
)
pipeline = []
for search, result_col in search_replacements.items():
score_function = (
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore"
)
pipeline.extend(
[
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
"$meta": score_function
}
}
},
]
)
return pipeline
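# Illustrative sketch (not part of the diff): with a single full-text search
# annotation, the pipeline returned above has roughly this shape; the operator
# body inside $search depends on the SearchExpression subclass, and the field
# and query values here are hypothetical.
#
#   [
#       {"$search": {"text": {"path": "headline", "query": "mongodb"}}},
#       {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}},
#   ]
#
# A SearchVector annotation emits a $vectorSearch stage instead, scored with
# {"$meta": "vectorSearchScore"}.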

def pre_sql_setup(self, with_col_aliases=False):
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by)
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
all_replacements = {**search_replacements, **group_replacements}
self.search_pipeline = self._compound_searches_queries(search_replacements)
Comment on lines +309 to +310
Contributor

Why don't we pass all_replacements into self._compound_searches_queries?

Collaborator Author

We could, but then it would need to filter them to check that there's only one search or vector search operator. It composes only the search operators that are stored in the replacements. (search_replacements has multiple uses; I can try to refactor it a bit.)

# query.group_by is either:
# - None: no GROUP BY
# - True: group by select fields
Expand All @@ -233,6 +332,8 @@ def pre_sql_setup(self, with_col_aliases=False):
for target, expr in self.query.annotation_select.items()
}
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by]
if (where := self.get_where()) and search_replacements:
self.set_where(where.replace_expressions(search_replacements))
return extra_select, order_by, group_by

def execute_sql(
@@ -555,10 +656,16 @@ def get_lookup_pipeline(self):
return result

def _get_aggregate_expressions(self, expr):
return self._get_all_expressions_of_type(expr, Aggregate)

def _get_search_expressions(self, expr):
return self._get_all_expressions_of_type(expr, SearchExpression)

def _get_all_expressions_of_type(self, expr, target_type):
stack = [expr]
while stack:
expr = stack.pop()
if isinstance(expr, Aggregate):
if isinstance(expr, target_type):
yield expr
elif hasattr(expr, "get_source_expressions"):
stack.extend(expr.get_source_expressions())
@@ -627,6 +734,9 @@ def _get_ordering(self):
def get_where(self):
return getattr(self, "where", self.query.where)

def set_where(self, value):
self.where = value

def explain_query(self):
# Validate format (none supported) and options.
options = self.connection.ops.explain_query_prefix(
39 changes: 39 additions & 0 deletions django_mongodb_backend/expressions/__init__.py
@@ -0,0 +1,39 @@
from .search import (
CombinedSearchExpression,
CompoundExpression,
SearchAutocomplete,
SearchEquals,
SearchExists,
SearchGeoShape,
SearchGeoWithin,
SearchIn,
SearchMoreLikeThis,
SearchPhrase,
SearchQueryString,
SearchRange,
SearchRegex,
SearchScoreOption,
SearchText,
SearchVector,
SearchWildcard,
)

__all__ = [
"CombinedSearchExpression",
"CompoundExpression",
"SearchAutocomplete",
"SearchEquals",
"SearchExists",
"SearchGeoShape",
"SearchGeoWithin",
"SearchIn",
"SearchMoreLikeThis",
"SearchPhrase",
"SearchQueryString",
"SearchRange",
"SearchRegex",
"SearchScoreOption",
"SearchText",
"SearchVector",
"SearchWildcard",
]
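
A rough usage sketch (not part of this diff): assuming an Article model with an Atlas Search index on headline, an embedding field with an Atlas Vector Search index, and constructor signatures along the lines suggested by the expression names (the exact parameters are defined in expressions/search.py, not shown here):

from django_mongodb_backend.expressions import SearchText, SearchVector

from myapp.models import Article  # hypothetical app and model

# Annotate each Article with its relevance score, then filter and order on it;
# the compiler rewrites these references to the __search_expr.search# field.
qs = (
    Article.objects.annotate(score=SearchText("headline", "mongodb"))
    .filter(score__gt=0.5)
    .order_by("-score")
)

# Vector search over a hypothetical "embedding" field with a precomputed
# query vector; parameter names are assumptions.
query_vector = [0.1, 0.2, 0.3]
qs = Article.objects.annotate(
    score=SearchVector("embedding", query_vector, limit=10)
)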
django_mongodb_backend/expressions.py → django_mongodb_backend/expressions/builtins.py
@@ -25,7 +25,7 @@
)
from django.db.models.sql import Query

from .query_utils import process_lhs
from ..query_utils import process_lhs


def case(self, compiler, connection):
@@ -53,7 +53,7 @@ def case(self, compiler, connection):
}


def col(self, compiler, connection): # noqa: ARG001
def col(self, compiler, connection, as_path=False): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +71,9 @@ def col(self, compiler, connection): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"
if not as_path:
prefix = f"${prefix}"
return f"{prefix}{self.target.column}"


def col_pairs(self, compiler, connection):