Skip to content

Commit 0df1258

Browse files
committed
refactor: prefer pk/id from resource data over resource instance
test: add tests for `get_resource_id_from_instance` util function
1 parent 4436507 commit 0df1258

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

rest_framework_json_api/renderers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,12 @@ def build_json_resource_obj(
443443
# Determine type from the instance if the underlying model is polymorphic
444444
if force_type_resolution:
445445
resource_name = utils.get_resource_type_from_instance(resource_instance)
446-
resource_id = utils.get_resource_id_from_instance(resource_instance)
446+
resource_id = utils.get_resource_id_from_instance(
447+
resource
448+
) or utils.get_resource_id_from_instance(resource_instance)
447449
resource_data = {
448450
"type": resource_name,
449-
"id": resource_id,
451+
"id": force_str(resource_id) if resource_id is not None else None,
450452
"attributes": cls.extract_attributes(fields, resource),
451453
}
452454
relationships = cls.extract_relationships(fields, resource, resource_instance)

rest_framework_json_api/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,9 @@ def get_resource_id_from_instance(instance):
309309
if not instance:
310310
return None
311311
elif hasattr(instance, "id"):
312-
return encoding.force_str(instance.id)
312+
return instance.id
313313
elif hasattr(instance, "pk"):
314-
return encoding.force_str(instance.pk)
314+
return instance.pk
315315
elif isinstance(instance, dict):
316316
return instance.get("id", instance.get("pk", None))
317317
return None

tests/test_utils.py

+30
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
format_resource_type,
1515
format_value,
1616
get_related_resource_type,
17+
get_resource_id_from_instance,
1718
get_resource_name,
1819
get_resource_type_from_serializer,
1920
undo_format_field_name,
@@ -392,6 +393,35 @@ class SerializerWithoutResourceName(serializers.Serializer):
392393
)
393394

394395

396+
class ObjectWithId:
397+
id: int = 9
398+
399+
400+
class ObjectWithPk:
401+
pk: int = 7
402+
403+
404+
class ObjectWithPkAndId(ObjectWithId, ObjectWithPk):
405+
pass
406+
407+
408+
@pytest.mark.parametrize(
409+
"instance, expected",
410+
[
411+
(None, None),
412+
(BasicModel(id=5), 5),
413+
(ObjectWithId(), 9),
414+
(ObjectWithPk(), 7),
415+
(ObjectWithPkAndId(), 9),
416+
({"id": 11}, 11),
417+
({"pk": 13}, 13),
418+
({"pk": 11, "id": 13}, 13),
419+
],
420+
)
421+
def test_get_resource_id_from_instance(instance, expected):
422+
assert get_resource_id_from_instance(instance) == expected
423+
424+
395425
@pytest.mark.parametrize(
396426
"message,pointer,response,result",
397427
[

0 commit comments

Comments
 (0)