Skip to content

Commit 7360368

Browse files
committed
Add a failing test for Generic Serializing
1 parent 0ecebf2 commit 7360368

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

src/guidellm/benchmark/scheduler_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ def scheduler_register_benchmark_objects():
1616
SchedulerMessagingPydanticRegistry.register("GenerationRequestTimings")(
1717
GenerationRequestTimings
1818
)
19-
SchedulerMessagingPydanticRegistry.register("ScheduledRequestInfo")(
20-
ScheduledRequestInfo
21-
)
19+
SchedulerMessagingPydanticRegistry.register(
20+
"ScheduledRequestInfo[GenerationRequestTimings]"
21+
)(ScheduledRequestInfo[GenerationRequestTimings])

tests/unit/utils/test_encoding.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import uuid
4-
from typing import Any, Generic
4+
from typing import Any, Generic, TypeVar
55

66
import pytest
77
from pydantic import BaseModel, Field
@@ -22,12 +22,28 @@ class SampleModel(BaseModel):
2222
value: int = Field(description="Value field for testing")
2323

2424

25-
class ComplexModel(BaseModel):
25+
class SampleModelSubclass(SampleModel):
26+
"""Subclass of SampleModel for testing."""
27+
28+
extra_field: str
29+
30+
31+
SampleModelT = TypeVar("SampleModelT", bound=SampleModel)
32+
33+
34+
class ComplexModel(BaseModel, Generic[SampleModelT]):
2635
"""Complex Pydantic model for testing."""
2736

2837
items: list[str] = Field(default_factory=list)
2938
metadata: dict[str, Any] = Field(default_factory=dict)
30-
nested: SampleModel | None = Field(default=None)
39+
nested: SampleModelT | None = Field(default=None)
40+
41+
42+
class GenricModelWrapper(Generic[SampleModelT]):
43+
"""Simulates a layered generic type."""
44+
45+
def method(self, **kwargs) -> ComplexModel[SampleModelT]:
46+
return ComplexModel[SampleModelT](**kwargs)
3147

3248

3349
class TestMessageEncoding:
@@ -508,3 +524,31 @@ def test_dynamic_import_load_pydantic(self, monkeypatch):
508524
inst.pydantic_registry.clear()
509525
restored = inst.from_dict(dumped)
510526
assert restored == sample
527+
528+
@pytest.mark.sanity
529+
def test_generic_model(self):
530+
inst = Serializer("dict")
531+
inst.register_pydantic(ComplexModel[SampleModelSubclass])
532+
nested = ComplexModel[SampleModelSubclass](
533+
items=["i1", "i2"],
534+
metadata={"m": 1},
535+
nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"),
536+
)
537+
dumped = inst.to_dict(nested)
538+
restored = inst.from_dict(dumped)
539+
assert restored == nested
540+
541+
@pytest.mark.sanity
542+
def test_generic_emitted_type(self):
543+
generic_instance = GenricModelWrapper[SampleModelSubclass]()
544+
545+
inst = Serializer("dict")
546+
inst.register_pydantic(ComplexModel[SampleModelSubclass])
547+
nested = generic_instance.method(
548+
items=["i1", "i2"],
549+
metadata={"m": 1},
550+
nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"),
551+
)
552+
dumped = inst.to_dict(nested)
553+
restored = inst.from_dict(dumped)
554+
assert restored == nested

0 commit comments

Comments
 (0)