Skip to content

Commit 2b9483a

Browse files
committed
refactor: add get_id support for resource identifier
1 parent 0df1258 commit 2b9483a

File tree

4 files changed

+67
-24
lines changed

4 files changed

+67
-24
lines changed

rest_framework_json_api/renderers.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -443,12 +443,9 @@ def build_json_resource_obj(
443443
# Determine type from the instance if the underlying model is polymorphic
444444
if force_type_resolution:
445445
resource_name = utils.get_resource_type_from_instance(resource_instance)
446-
resource_id = utils.get_resource_id_from_instance(
447-
resource
448-
) or utils.get_resource_id_from_instance(resource_instance)
449446
resource_data = {
450447
"type": resource_name,
451-
"id": force_str(resource_id) if resource_id is not None else None,
448+
"id": utils.get_resource_id_from_instance(resource_instance, resource),
452449
"attributes": cls.extract_attributes(fields, resource),
453450
}
454451
relationships = cls.extract_relationships(fields, resource, resource_instance)

rest_framework_json_api/utils.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -304,16 +304,18 @@ def get_resource_type_from_serializer(serializer):
304304
)
305305

306306

307-
def get_resource_id_from_instance(instance):
307+
def get_resource_id_from_instance(resource_instance, resource):
308308
"""Returns the resource identifier for a given instance (`id` takes priority over `pk`)."""
309-
if not instance:
310-
return None
311-
elif hasattr(instance, "id"):
312-
return instance.id
313-
elif hasattr(instance, "pk"):
314-
return instance.pk
315-
elif isinstance(instance, dict):
316-
return instance.get("id", instance.get("pk", None))
309+
if resource and "id" in resource:
310+
return resource["id"] and encoding.force_str(resource["id"]) or None
311+
if resource_instance:
312+
if hasattr(resource_instance, "get_id") and callable(resource_instance.get_id):
313+
return encoding.force_str(resource_instance.get_id())
314+
return (
315+
hasattr(resource_instance, "pk")
316+
and encoding.force_str(resource_instance.pk)
317+
or None
318+
)
317319
return None
318320

319321

tests/test_utils.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -406,20 +406,23 @@ class ObjectWithPkAndId(ObjectWithId, ObjectWithPk):
406406

407407

408408
@pytest.mark.parametrize(
409-
"instance, expected",
409+
"resource_instance, resource, expected",
410410
[
411-
(None, None),
412-
(BasicModel(id=5), 5),
413-
(ObjectWithId(), 9),
414-
(ObjectWithPk(), 7),
415-
(ObjectWithPkAndId(), 9),
416-
({"id": 11}, 11),
417-
({"pk": 13}, 13),
418-
({"pk": 11, "id": 13}, 13),
411+
(None, None, None),
412+
(object(), {}, None),
413+
(BasicModel(id=5), None, "5"),
414+
(ObjectWithId(), {}, "9"),
415+
(ObjectWithPk(), None, "7"),
416+
(ObjectWithPkAndId(), None, "9"),
417+
(None, {"id": 11}, "11"),
418+
(object(), {"pk": 11}, None),
419+
(ObjectWithId(), {"id": 11}, "11"),
420+
(ObjectWithPk(), {"pk": 13}, "7"),
421+
(ObjectWithPkAndId(), {"id": 12, "pk": 13}, "12"),
419422
],
420423
)
421-
def test_get_resource_id_from_instance(instance, expected):
422-
assert get_resource_id_from_instance(instance) == expected
424+
def test_get_resource_id_from_instance(resource_instance, resource, expected):
425+
assert get_resource_id_from_instance(resource_instance, resource) == expected
423426

424427

425428
@pytest.mark.parametrize(

tests/test_views.py

+41
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,28 @@ def test_post_with_missing_id(self, client):
205205
}
206206
}
207207

208+
@pytest.mark.urls(__name__)
209+
def test_custom_id(self, client):
210+
data = {
211+
"data": {
212+
"id": 2_193_102,
213+
"type": "custom",
214+
"attributes": {"body": "hello"},
215+
}
216+
}
217+
218+
url = reverse("custom-id")
219+
220+
response = client.patch(url, data=data)
221+
assert response.status_code == status.HTTP_200_OK
222+
assert response.json() == {
223+
"data": {
224+
"type": "custom",
225+
"id": "2176ce", # get_id() -> hex
226+
"attributes": {"body": "hello"},
227+
}
228+
}
229+
208230

209231
# Routing setup
210232

@@ -224,6 +246,14 @@ class CustomModelSerializer(serializers.Serializer):
224246
id = serializers.IntegerField()
225247

226248

249+
class CustomIdModelSerializer(serializers.Serializer):
250+
body = serializers.CharField()
251+
id = serializers.IntegerField()
252+
253+
def get_id(self):
254+
return hex(self.validated_data["id"])[2:]
255+
256+
227257
class CustomAPIView(APIView):
228258
parser_classes = [JSONParser]
229259
renderer_classes = [JSONRenderer]
@@ -238,10 +268,21 @@ def post(self, request, *args, **kwargs):
238268
return Response(status=status.HTTP_200_OK, data=serializer.data)
239269

240270

271+
class CustomIdAPIView(APIView):
272+
parser_classes = [JSONParser]
273+
renderer_classes = [JSONRenderer]
274+
resource_name = "custom"
275+
276+
def patch(self, request, *args, **kwargs):
277+
serializer = CustomIdModelSerializer(CustomModel(request.data))
278+
return Response(status=status.HTTP_200_OK, data=serializer.data)
279+
280+
241281
router = SimpleRouter()
242282
router.register(r"basic_models", BasicModelViewSet, basename="basic-model")
243283

244284
urlpatterns = [
245285
path("custom", CustomAPIView.as_view(), name="custom"),
286+
path("custom-id", CustomIdAPIView.as_view(), name="custom-id"),
246287
]
247288
urlpatterns += router.urls

0 commit comments

Comments
 (0)