Skip to content

Commit 37b4d42

Browse files
committed
Merge pull request #3225 from tomchristie/maxpeterson-grouped-choices-fix
Support grouped choices.
2 parents 9a77879 + 33d6d4a commit 37b4d42

File tree

11 files changed

+267
-29
lines changed

11 files changed

+267
-29
lines changed

rest_framework/fields.py

+81-11
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,53 @@ def set_value(dictionary, keys, value):
108108
dictionary[keys[-1]] = value
109109

110110

111+
def to_choices_dict(choices):
112+
"""
113+
Convert choices into key/value dicts.
114+
115+
pairwise_choices([1]) -> {1: 1}
116+
pairwise_choices([(1, '1st'), (2, '2nd')]) -> {1: '1st', 2: '2nd'}
117+
pairwise_choices([('Group', ((1, '1st'), 2))]) -> {'Group': {1: '1st', 2: '2nd'}}
118+
"""
119+
# Allow single, paired or grouped choices style:
120+
# choices = [1, 2, 3]
121+
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
122+
# choices = [('Category', ((1, 'First'), (2, 'Second'))), (3, 'Third')]
123+
ret = OrderedDict()
124+
for choice in choices:
125+
if (not isinstance(choice, (list, tuple))):
126+
# single choice
127+
ret[choice] = choice
128+
else:
129+
key, value = choice
130+
if isinstance(value, (list, tuple)):
131+
# grouped choices (category, sub choices)
132+
ret[key] = to_choices_dict(value)
133+
else:
134+
# paired choice (key, display value)
135+
ret[key] = value
136+
return ret
137+
138+
139+
def flatten_choices_dict(choices):
140+
"""
141+
Convert a group choices dict into a flat dict of choices.
142+
143+
flatten_choices({1: '1st', 2: '2nd'}) -> {1: '1st', 2: '2nd'}
144+
flatten_choices({'Group': {1: '1st', 2: '2nd'}}) -> {1: '1st', 2: '2nd'}
145+
"""
146+
ret = OrderedDict()
147+
for key, value in choices.items():
148+
if isinstance(value, dict):
149+
# grouped choices (category, sub choices)
150+
for sub_key, sub_value in value.items():
151+
ret[sub_key] = sub_value
152+
else:
153+
# choice (key, display value)
154+
ret[key] = value
155+
return ret
156+
157+
111158
class CreateOnlyDefault(object):
112159
"""
113160
This class may be used to provide default values that are only used
@@ -1111,17 +1158,8 @@ class ChoiceField(Field):
11111158
}
11121159

11131160
def __init__(self, choices, **kwargs):
1114-
# Allow either single or paired choices style:
1115-
# choices = [1, 2, 3]
1116-
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
1117-
pairs = [
1118-
isinstance(item, (list, tuple)) and len(item) == 2
1119-
for item in choices
1120-
]
1121-
if all(pairs):
1122-
self.choices = OrderedDict([(key, display_value) for key, display_value in choices])
1123-
else:
1124-
self.choices = OrderedDict([(item, item) for item in choices])
1161+
self.grouped_choices = to_choices_dict(choices)
1162+
self.choices = flatten_choices_dict(self.grouped_choices)
11251163

11261164
# Map the string representation of choices to the underlying value.
11271165
# Allows us to deal with eg. integer choices while supporting either
@@ -1148,6 +1186,38 @@ def to_representation(self, value):
11481186
return value
11491187
return self.choice_strings_to_values.get(six.text_type(value), value)
11501188

1189+
def iter_options(self):
1190+
"""
1191+
Helper method for use with templates rendering select widgets.
1192+
"""
1193+
class StartOptionGroup(object):
1194+
start_option_group = True
1195+
end_option_group = False
1196+
1197+
def __init__(self, label):
1198+
self.label = label
1199+
1200+
class EndOptionGroup(object):
1201+
start_option_group = False
1202+
end_option_group = True
1203+
1204+
class Option(object):
1205+
start_option_group = False
1206+
end_option_group = False
1207+
1208+
def __init__(self, value, display_text):
1209+
self.value = value
1210+
self.display_text = display_text
1211+
1212+
for key, value in self.grouped_choices.items():
1213+
if isinstance(value, dict):
1214+
yield StartOptionGroup(label=key)
1215+
for sub_key, sub_value in value.items():
1216+
yield Option(value=sub_key, display_text=sub_value)
1217+
yield EndOptionGroup()
1218+
else:
1219+
yield Option(value=key, display_text=value)
1220+
11511221

11521222
class MultipleChoiceField(ChoiceField):
11531223
default_error_messages = {

rest_framework/templates/rest_framework/horizontal/select.html

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010
{% if field.allow_null or field.allow_blank %}
1111
<option value="" {% if not field.value %}selected{% endif %}>--------</option>
1212
{% endif %}
13-
{% for key, text in field.choices.items %}
14-
<option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option>
13+
{% for select in field.iter_options %}
14+
{% if select.start_option_group %}
15+
<optgroup label="{{ select.label }}">
16+
{% elif select.end_option_group %}
17+
</optgroup>
18+
{% else %}
19+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
20+
{% endif %}
1521
{% endfor %}
1622
</select>
1723

rest_framework/templates/rest_framework/horizontal/select_multiple.html

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010

1111
<div class="col-sm-10">
1212
<select multiple {{ field.choices|yesno:",disabled" }} class="form-control" name="{{ field.name }}">
13-
{% for key, text in field.choices.items %}
14-
<option value="{{ key }}" {% if key in field.value %}selected{% endif %}>{{ text }}</option>
13+
{% for select in field.iter_options %}
14+
{% if select.start_option_group %}
15+
<optgroup label="{{ select.label }}">
16+
{% elif select.end_option_group %}
17+
</optgroup>
18+
{% else %}
19+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
20+
{% endif %}
1521
{% empty %}
1622
<option>{{ no_items }}</option>
1723
{% endfor %}

rest_framework/templates/rest_framework/inline/select.html

+8-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
{% if field.allow_null or field.allow_blank %}
1010
<option value="" {% if not field.value %}selected{% endif %}>--------</option>
1111
{% endif %}
12-
13-
{% for key, text in field.choices.items %}
14-
<option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option>
12+
{% for select in field.iter_options %}
13+
{% if select.start_option_group %}
14+
<optgroup label="{{ select.label }}">
15+
{% elif select.end_option_group %}
16+
</optgroup>
17+
{% else %}
18+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
19+
{% endif %}
1520
{% endfor %}
1621
</select>
1722
</div>

rest_framework/templates/rest_framework/inline/select_multiple.html

+9-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@
99
{% endif %}
1010

1111
<select multiple {{ field.choices|yesno:",disabled" }} class="form-control" name="{{ field.name }}">
12-
{% for key, text in field.choices.items %}
13-
<option value="{{ key }}" {% if key in field.value %}selected{% endif %}>{{ text }}</option>
14-
{% empty %}
12+
{% for select in field.iter_options %}
13+
{% if select.start_option_group %}
14+
<optgroup label="{{ select.label }}">
15+
{% elif select.end_option_group %}
16+
</optgroup>
17+
{% else %}
18+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
19+
{% endif %}
20+
{% empty %}
1521
<option>{{ no_items }}</option>
1622
{% endfor %}
1723
</select>

rest_framework/templates/rest_framework/vertical/select.html

+8-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
{% if field.allow_null or field.allow_blank %}
1010
<option value="" {% if not field.value %}selected{% endif %}>--------</option>
1111
{% endif %}
12-
13-
{% for key, text in field.choices.items %}
14-
<option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option>
12+
{% for select in field.iter_options %}
13+
{% if select.start_option_group %}
14+
<optgroup label="{{ select.label }}">
15+
{% elif select.end_option_group %}
16+
</optgroup>
17+
{% else %}
18+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
19+
{% endif %}
1520
{% endfor %}
1621
</select>
1722

rest_framework/templates/rest_framework/vertical/select_multiple.html

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@
99
{% endif %}
1010

1111
<select multiple {{ field.choices|yesno:",disabled" }} class="form-control" name="{{ field.name }}">
12-
{% for key, text in field.choices.items %}
13-
<option value="{{ key }}" {% if key in field.value %}selected{% endif %}>{{ text }}</option>
12+
{% for select in field.iter_options %}
13+
{% if select.start_option_group %}
14+
<optgroup label="{{ select.label }}">
15+
{% elif select.end_option_group %}
16+
</optgroup>
17+
{% else %}
18+
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %}>{{ select.display_text }}</option>
19+
{% endif %}
1420
{% empty %}
1521
<option>{{ no_items }}</option>
1622
{% endfor %}

rest_framework/utils/field_mapping.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def get_field_kwargs(field_name, model_field):
107107
isinstance(model_field, models.TextField)):
108108
kwargs['allow_blank'] = True
109109

110-
if model_field.flatchoices:
110+
if model_field.choices:
111111
# If this model field contains choices, then return early.
112112
# Further keyword arguments are not valid.
113-
kwargs['choices'] = model_field.flatchoices
113+
kwargs['choices'] = model_field.choices
114114
return kwargs
115115

116116
# Ensure that max_length is passed explicitly as a keyword arg,

tests/test_fields.py

+88
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,34 @@ def test_allow_null(self):
11071107
output = field.run_validation(None)
11081108
assert output is None
11091109

1110+
def test_iter_options(self):
1111+
"""
1112+
iter_options() should return a list of options and option groups.
1113+
"""
1114+
field = serializers.ChoiceField(
1115+
choices=[
1116+
('Numbers', ['integer', 'float']),
1117+
('Strings', ['text', 'email', 'url']),
1118+
'boolean'
1119+
]
1120+
)
1121+
items = list(field.iter_options())
1122+
1123+
assert items[0].start_option_group
1124+
assert items[0].label == 'Numbers'
1125+
assert items[1].value == 'integer'
1126+
assert items[2].value == 'float'
1127+
assert items[3].end_option_group
1128+
1129+
assert items[4].start_option_group
1130+
assert items[4].label == 'Strings'
1131+
assert items[5].value == 'text'
1132+
assert items[6].value == 'email'
1133+
assert items[7].value == 'url'
1134+
assert items[8].end_option_group
1135+
1136+
assert items[9].value == 'boolean'
1137+
11101138

11111139
class TestChoiceFieldWithType(FieldValues):
11121140
"""
@@ -1153,6 +1181,66 @@ class TestChoiceFieldWithListChoices(FieldValues):
11531181
field = serializers.ChoiceField(choices=('poor', 'medium', 'good'))
11541182

11551183

1184+
class TestChoiceFieldWithGroupedChoices(FieldValues):
1185+
"""
1186+
Valid and invalid values for a `Choice` field that uses a grouped list for the
1187+
choices, rather than a list of pairs of (`value`, `description`).
1188+
"""
1189+
valid_inputs = {
1190+
'poor': 'poor',
1191+
'medium': 'medium',
1192+
'good': 'good',
1193+
}
1194+
invalid_inputs = {
1195+
'awful': ['"awful" is not a valid choice.']
1196+
}
1197+
outputs = {
1198+
'good': 'good'
1199+
}
1200+
field = serializers.ChoiceField(
1201+
choices=[
1202+
(
1203+
'Category',
1204+
(
1205+
('poor', 'Poor quality'),
1206+
('medium', 'Medium quality'),
1207+
),
1208+
),
1209+
('good', 'Good quality'),
1210+
]
1211+
)
1212+
1213+
1214+
class TestChoiceFieldWithMixedChoices(FieldValues):
1215+
"""
1216+
Valid and invalid values for a `Choice` field that uses a single paired or
1217+
grouped.
1218+
"""
1219+
valid_inputs = {
1220+
'poor': 'poor',
1221+
'medium': 'medium',
1222+
'good': 'good',
1223+
}
1224+
invalid_inputs = {
1225+
'awful': ['"awful" is not a valid choice.']
1226+
}
1227+
outputs = {
1228+
'good': 'good'
1229+
}
1230+
field = serializers.ChoiceField(
1231+
choices=[
1232+
(
1233+
'Category',
1234+
(
1235+
('poor', 'Poor quality'),
1236+
),
1237+
),
1238+
'medium',
1239+
('good', 'Good quality'),
1240+
]
1241+
)
1242+
1243+
11561244
class TestMultipleChoiceField(FieldValues):
11571245
"""
11581246
Valid and invalid values for `MultipleChoiceField`.

tests/test_model_serializer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class Meta:
181181
null_field = IntegerField(allow_null=True, required=False)
182182
default_field = IntegerField(required=False)
183183
descriptive_field = IntegerField(help_text='Some help text', label='A label')
184-
choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')])
184+
choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')))
185185
""")
186186
if six.PY2:
187187
# This particular case is too awkward to resolve fully across

0 commit comments

Comments
 (0)