-
Notifications
You must be signed in to change notification settings - Fork 26
Atlas search lookups #325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Atlas search lookups #325
Changes from all commits
b774f21
007ad74
2b78216
6211d89
1828ee4
d4c2743
23f90dd
67f0bf3
fd791f3
a7af873
67f9c86
6b99ecf
11cc8bf
6d8edba
366c151
a937cb1
c195010
99f6548
2b8e2b0
2baafcf
b491d6e
d06ca78
c035a6c
eac06d1
ed741f5
d869ed0
a08d556
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,4 +81,4 @@ repos: | |
rev: "v2.2.6" | ||
hooks: | ||
- id: codespell | ||
args: ["-L", "nin"] | ||
args: ["-L", "nin", "-L", "searchin"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from django.utils.functional import cached_property | ||
from pymongo import ASCENDING, DESCENDING | ||
|
||
from .expressions.search import SearchExpression, SearchVector | ||
from .query import MongoQuery, wrap_database_errors | ||
|
||
|
||
|
@@ -33,6 +34,8 @@ def __init__(self, *args, **kwargs): | |
# A list of OrderBy objects for this query. | ||
self.order_by_objs = None | ||
self.subqueries = [] | ||
# Atlas search calls | ||
self.search_pipeline = [] | ||
|
||
def _get_group_alias_column(self, expr, annotation_group_idx): | ||
"""Generate a dummy field for use in the ids fields in $group.""" | ||
|
@@ -56,6 +59,29 @@ def _get_column_from_expression(self, expr, alias): | |
column_target.set_attributes_from_name(alias) | ||
return Col(self.collection_name, column_target) | ||
|
||
def _get_replace_expr(self, sub_expr, group, alias): | ||
column_target = sub_expr.output_field.clone() | ||
column_target.db_column = alias | ||
column_target.set_attributes_from_name(alias) | ||
inner_column = Col(self.collection_name, column_target) | ||
if getattr(sub_expr, "distinct", False): | ||
# If the expression should return distinct values, use | ||
# $addToSet to deduplicate. | ||
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) | ||
group[alias] = {"$addToSet": rhs} | ||
replacing_expr = sub_expr.copy() | ||
replacing_expr.set_source_expressions([inner_column, None]) | ||
else: | ||
group[alias] = sub_expr.as_mql(self, self.connection) | ||
replacing_expr = inner_column | ||
# Count must return 0 rather than null. | ||
if isinstance(sub_expr, Count): | ||
replacing_expr = Coalesce(replacing_expr, 0) | ||
# Variance = StdDev^2 | ||
if isinstance(sub_expr, Variance): | ||
replacing_expr = Power(replacing_expr, 2) | ||
return replacing_expr | ||
|
||
def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx): | ||
""" | ||
Prepare expressions for the aggregation pipeline. | ||
|
@@ -79,29 +105,45 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group | |
alias = ( | ||
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target | ||
) | ||
column_target = sub_expr.output_field.clone() | ||
column_target.db_column = alias | ||
column_target.set_attributes_from_name(alias) | ||
inner_column = Col(self.collection_name, column_target) | ||
if sub_expr.distinct: | ||
# If the expression should return distinct values, use | ||
# $addToSet to deduplicate. | ||
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) | ||
group[alias] = {"$addToSet": rhs} | ||
replacing_expr = sub_expr.copy() | ||
replacing_expr.set_source_expressions([inner_column, None]) | ||
else: | ||
group[alias] = sub_expr.as_mql(self, self.connection) | ||
replacing_expr = inner_column | ||
# Count must return 0 rather than null. | ||
if isinstance(sub_expr, Count): | ||
replacing_expr = Coalesce(replacing_expr, 0) | ||
# Variance = StdDev^2 | ||
if isinstance(sub_expr, Variance): | ||
replacing_expr = Power(replacing_expr, 2) | ||
replacements[sub_expr] = replacing_expr | ||
replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias) | ||
return replacements, group | ||
|
||
def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements): | ||
searches = {} | ||
for sub_expr in self._get_search_expressions(expression): | ||
if sub_expr not in replacements: | ||
alias = f"__search_expr.search{next(search_idx)}" | ||
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) | ||
|
||
def _prepare_search_query_for_aggregation_pipeline(self, order_by): | ||
""" | ||
Prepare expressions for the search pipeline. | ||
|
||
Handle the computation of search functions used by various | ||
expressions. Separate and create intermediate columns, and replace | ||
nodes to simulate a search operation. | ||
|
||
MongoDB's $search or $searchVector are stages. To apply operations over them, | ||
compute the $search or $vectorSearch first, then apply additional operations in a subsequent | ||
stage by replacing the aggregate expressions with new document field prefixed | ||
by `__search_expr.search#`. | ||
""" | ||
replacements = {} | ||
annotation_group_idx = itertools.count(start=1) | ||
for expr in self.query.annotation_select.values(): | ||
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) | ||
|
||
for expr, _ in order_by: | ||
self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) | ||
|
||
self._prepare_search_expressions_for_pipeline( | ||
self.having, annotation_group_idx, replacements | ||
) | ||
self._prepare_search_expressions_for_pipeline( | ||
self.get_where(), annotation_group_idx, replacements | ||
) | ||
return replacements | ||
|
||
def _prepare_annotations_for_aggregation_pipeline(self, order_by): | ||
"""Prepare annotations for the aggregation pipeline.""" | ||
replacements = {} | ||
|
@@ -206,9 +248,66 @@ def _build_aggregation_pipeline(self, ids, group): | |
pipeline.append({"$unset": "_id"}) | ||
return pipeline | ||
|
||
def _compound_searches_queries(self, search_replacements): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I want to preserve this function for the future; we will probably want to build hybrid search, and this part of the code could be useful for that. I know that it is weird — check the replacement. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm fine with it; please just add a docstring to explain the function and an additional comment explaining the need for the checks. |
||
""" | ||
Builds a query pipeline from a mapping of search expressions to result columns. | ||
|
||
Currently, only a single `$search` or `$vectorSearch` expression is supported. | ||
Combining multiple search expressions is not yet allowed and will raise a ValueError. | ||
|
||
This method will eventually support hybrid search by allowing the combination of | ||
`$search` and `$vectorSearch` operations. | ||
""" | ||
if not search_replacements: | ||
return [] | ||
if len(search_replacements) > 1: | ||
has_search = any(not isinstance(search, SearchVector) for search in search_replacements) | ||
has_vector_search = any( | ||
isinstance(search, SearchVector) for search in search_replacements | ||
) | ||
if has_search and has_vector_search: | ||
raise ValueError( | ||
"Cannot combine a `$vectorSearch` with a `$search` operator. " | ||
"If you need to combine them, consider restructuring your query logic or " | ||
"running them as separate queries." | ||
) | ||
if has_vector_search: | ||
raise ValueError( | ||
"Cannot combine two `$vectorSearch` operator. " | ||
"If you need to combine them, consider restructuring your query logic or " | ||
"running them as separate queries." | ||
) | ||
raise ValueError( | ||
"Only one $search operation is allowed per query. " | ||
f"Received {len(search_replacements)} search expressions. " | ||
"To combine multiple search expressions, use either a CompoundExpression for " | ||
"fine-grained control or CombinedSearchExpression for simple logical combinations." | ||
) | ||
pipeline = [] | ||
for search, result_col in search_replacements.items(): | ||
score_function = ( | ||
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore" | ||
) | ||
pipeline.extend( | ||
[ | ||
search.as_mql(self, self.connection), | ||
{ | ||
"$addFields": { | ||
result_col.as_mql(self, self.connection, as_path=True): { | ||
"$meta": score_function | ||
} | ||
} | ||
}, | ||
] | ||
) | ||
return pipeline | ||
|
||
def pre_sql_setup(self, with_col_aliases=False): | ||
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) | ||
group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) | ||
search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by) | ||
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) | ||
all_replacements = {**search_replacements, **group_replacements} | ||
self.search_pipeline = self._compound_searches_queries(search_replacements) | ||
Comment on lines
+309
to
+310
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could. But then it will need to filter them to check that it has one search or vector search operator. It composes only the search operators that are stored in the replacements. ( |
||
# query.group_by is either: | ||
# - None: no GROUP BY | ||
# - True: group by select fields | ||
|
@@ -233,6 +332,8 @@ def pre_sql_setup(self, with_col_aliases=False): | |
for target, expr in self.query.annotation_select.items() | ||
} | ||
self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by] | ||
if (where := self.get_where()) and search_replacements: | ||
self.set_where(where.replace_expressions(search_replacements)) | ||
return extra_select, order_by, group_by | ||
|
||
def execute_sql( | ||
|
@@ -555,10 +656,16 @@ def get_lookup_pipeline(self): | |
return result | ||
|
||
def _get_aggregate_expressions(self, expr): | ||
return self._get_all_expressions_of_type(expr, Aggregate) | ||
|
||
def _get_search_expressions(self, expr): | ||
return self._get_all_expressions_of_type(expr, SearchExpression) | ||
|
||
def _get_all_expressions_of_type(self, expr, target_type): | ||
stack = [expr] | ||
while stack: | ||
expr = stack.pop() | ||
if isinstance(expr, Aggregate): | ||
if isinstance(expr, target_type): | ||
yield expr | ||
elif hasattr(expr, "get_source_expressions"): | ||
stack.extend(expr.get_source_expressions()) | ||
|
@@ -627,6 +734,9 @@ def _get_ordering(self): | |
def get_where(self): | ||
return getattr(self, "where", self.query.where) | ||
|
||
def set_where(self, value): | ||
self.where = value | ||
|
||
def explain_query(self): | ||
# Validate format (none supported) and options. | ||
options = self.connection.ops.explain_query_prefix( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from .search import ( | ||
CombinedSearchExpression, | ||
CompoundExpression, | ||
SearchAutocomplete, | ||
SearchEquals, | ||
SearchExists, | ||
SearchGeoShape, | ||
SearchGeoWithin, | ||
SearchIn, | ||
SearchMoreLikeThis, | ||
SearchPhrase, | ||
SearchQueryString, | ||
SearchRange, | ||
SearchRegex, | ||
SearchScoreOption, | ||
SearchText, | ||
SearchVector, | ||
SearchWildcard, | ||
) | ||
|
||
__all__ = [ | ||
"CombinedSearchExpression", | ||
"CompoundExpression", | ||
"SearchAutocomplete", | ||
"SearchEquals", | ||
"SearchExists", | ||
"SearchGeoShape", | ||
"SearchGeoWithin", | ||
"SearchIn", | ||
"SearchMoreLikeThis", | ||
"SearchPhrase", | ||
"SearchQueryString", | ||
"SearchRange", | ||
"SearchRegex", | ||
"SearchScoreOption", | ||
"SearchText", | ||
"SearchVector", | ||
"SearchWildcard", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
I know these are private functions, but can they get a docstring? Same with `_get_replace_expr`. It's quite complex code, so it becomes harder to follow.