diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 90c19aba08..eaca4ca56e 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -156,26 +156,36 @@ def get_default_ordering(self, view): return ordering def remove_invalid_fields(self, queryset, fields, view): - valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + ordering_fields = getattr(view, 'ordering_fields', self.ordering_fields) - if valid_fields is None: - # Default to allowing filtering on serializer fields - serializer_class = getattr(view, 'serializer_class') + if not ordering_fields == '__all__': + serializer_class = getattr(view, 'serializer_class') or view.get_serializer_class() if serializer_class is None: msg = ("Cannot use %s on a view which does not have either a " "'serializer_class' or 'ordering_fields' attribute.") raise ImproperlyConfigured(msg % self.__class__.__name__) + + if ordering_fields is None: + # Default to allowing filtering on serializer field names (return field sources) valid_fields = [ - field.source or field_name + (field.source, field_name) for field_name, field in serializer_class().fields.items() if not getattr(field, 'write_only', False) ] - elif valid_fields == '__all__': + return [term[0] for term in valid_fields if term[0] != "*"] + elif ordering_fields == '__all__': # View explicitly allows filtering on any model field valid_fields = [field.name for field in queryset.model._meta.fields] valid_fields += queryset.query.aggregates.keys() - - return [term for term in fields if term.lstrip('-') in valid_fields] + return [term for term in fields if term.lstrip('-') in valid_fields] + else: + # Allow filtering on defined field name (return field sources) + valid_fields = [ + (field.source, field_name) + for field_name, field in serializer_class().fields.items() + if not getattr(field, 'write_only', False) + ] + return [term[0] for term in valid_fields if term[0] != "*" and term[1].lstrip('-') in fields] def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request, queryset, view)