Skip to content

Commit ab6dbe9

Browse files
authored
Merge branch 'master' into patch-1
2 parents f2a219c + 50c38a5 commit ab6dbe9

File tree

2 files changed

+59
-11
lines changed

2 files changed

+59
-11
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@ class User:
3434
})
3535
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636
"""
37-
import dataclasses
3837
import inspect
39-
import typing
40-
from functools import lru_cache
4138
from enum import EnumMeta
39+
from functools import lru_cache
4240
from typing import (
4341
overload,
4442
Dict,
@@ -55,6 +53,7 @@ class User:
5553
Set,
5654
)
5755

56+
import dataclasses
5857
import marshmallow
5958
import typing_inspect
6059

@@ -63,9 +62,6 @@ class User:
6362
NoneType = type(None)
6463
_U = TypeVar("_U")
6564

66-
T_Schema = TypeVar("T_Schema", bound=marshmallow.Schema)
67-
T_SchemaType = Type[T_Schema]
68-
6965
# Whitelist of dataclass members that will be copied to generated schema.
7066
MEMBERS_WHITELIST: Set[str] = {"Meta"}
7167

@@ -165,12 +161,15 @@ def decorator(clazz: Type[_U]) -> Type[_U]:
165161
return decorator(_cls) if _cls else decorator
166162

167163

168-
@typing.overload
164+
T_Schema = TypeVar("T_Schema", bound=marshmallow.Schema)
165+
T_SchemaType = Type[T_Schema]
166+
167+
@overload
169168
def class_schema(clazz: type, base_schema: None = None) -> Type[marshmallow.Schema]:
170169
...
171170

172171

173-
@typing.overload
172+
@overload
174173
def class_schema(clazz: type, base_schema: T_SchemaType) -> T_SchemaType:
175174
...
176175

@@ -469,12 +468,17 @@ def _base_schema(
469468
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
470469
or `BaseSchema`
471470
"""
471+
472472
# Remove `type: ignore` when mypy handles dynamic base classes
473473
# https://github.com/python/mypy/issues/2813
474474
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
475-
@marshmallow.post_load
476-
def make_data_class(self, data, **_):
477-
return clazz(**data)
475+
def load(self, data: Mapping, *, many: bool = None, **kwargs):
476+
all_loaded = super().load(data, many=many, **kwargs)
477+
many = self.many if many is None else bool(many)
478+
if many:
479+
return [clazz(**loaded) for loaded in all_loaded]
480+
else:
481+
return clazz(**all_loaded)
478482

479483
return BaseSchema
480484

tests/test_post_load.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
3+
import marshmallow
4+
5+
import marshmallow_dataclass
6+
7+
8+
# Regression test for https://github.com/lovasoa/marshmallow_dataclass/issues/75
9+
class TestPostLoad(unittest.TestCase):
10+
@marshmallow_dataclass.dataclass
11+
class Named:
12+
first_name: str
13+
last_name: str
14+
15+
@marshmallow.post_load
16+
def a(self, data, **_kwargs):
17+
data["first_name"] = data["first_name"].capitalize()
18+
return data
19+
20+
@marshmallow.post_load
21+
def z(self, data, **_kwargs):
22+
data["last_name"] = data["last_name"].capitalize()
23+
return data
24+
25+
def test_post_load_method_naming_does_not_affect_data(self):
26+
actual = self.Named.Schema().load(
27+
{"first_name": "matt", "last_name": "groening"}
28+
)
29+
expected = self.Named(first_name="Matt", last_name="Groening")
30+
self.assertEqual(actual, expected)
31+
32+
def test_load_many(self):
33+
actual = self.Named.Schema().load(
34+
[
35+
{"first_name": "matt", "last_name": "groening"},
36+
{"first_name": "bart", "last_name": "simpson"},
37+
],
38+
many=True,
39+
)
40+
expected = [
41+
self.Named(first_name="Matt", last_name="Groening"),
42+
self.Named(first_name="Bart", last_name="Simpson"),
43+
]
44+
self.assertEqual(actual, expected)

0 commit comments

Comments
 (0)