Skip to content

Commit fe2aede

Browse files
committed
More robust default behavior on OrderingFilter (#4156)
1 parent dc09eef commit fe2aede

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

rest_framework/filters.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,24 +222,40 @@ def get_default_ordering(self, view):
222222
return (ordering,)
223223
return ordering
224224

225+
def get_default_valid_fields(self, queryset, view):
226+
# If `ordering_fields` is not specified, then we determine a default
227+
# based on the serializer class, if one exists on the view.
228+
if hasattr(view, 'get_serializer_class'):
229+
try:
230+
serializer_class = view.get_serializer_class()
231+
except AssertionError:
232+
# Raised by the default implementation if
233+
# no serializer_class was found
234+
serializer_class = None
235+
else:
236+
serializer_class = getattr(view, 'serializer_class', None)
237+
238+
if serializer_class is None:
239+
msg = (
240+
"Cannot use %s on a view which does not have either a "
241+
"'serializer_class', an overriding 'get_serializer_class' "
242+
"or 'ordering_fields' attribute."
243+
)
244+
raise ImproperlyConfigured(msg % self.__class__.__name__)
245+
246+
return [
247+
(field.source or field_name, field.label)
248+
for field_name, field in serializer_class().fields.items()
249+
if not getattr(field, 'write_only', False) and not field.source == '*'
250+
]
251+
225252
def get_valid_fields(self, queryset, view):
226253
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
227254

228255
if valid_fields is None:
229256
# Default to allowing filtering on serializer fields
230-
try:
231-
serializer_class = view.get_serializer_class()
232-
except AssertionError: # raised if no serializer_class was found
233-
msg = ("Cannot use %s on a view which does not have either a "
234-
"'serializer_class', an overriding 'get_serializer_class' "
235-
"or 'ordering_fields' attribute.")
236-
raise ImproperlyConfigured(msg % self.__class__.__name__)
257+
return self.get_default_valid_fields(queryset, view)
237258

238-
valid_fields = [
239-
(field.source or field_name, field.label)
240-
for field_name, field in serializer_class().fields.items()
241-
if not getattr(field, 'write_only', False) and not field.source == '*'
242-
]
243259
elif valid_fields == '__all__':
244260
# View explicitly allows filtering on any model field
245261
valid_fields = [

0 commit comments

Comments
 (0)