-
-
Notifications
You must be signed in to change notification settings - Fork 843
Expand file tree
/
Copy pathtest_generics.py
More file actions
122 lines (92 loc) · 3.85 KB
/
test_generics.py
File metadata and controls
122 lines (92 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from enum import Enum
from typing import Generic, Literal, TypeVar
import pytest
from sqlalchemy import create_engine
from sqlmodel import Field, Session, SQLModel, select
from typing_extensions import assert_type
def test_generic_type_with_bound(clear_sqlmodel) -> None:
    """Store and fetch a table model whose field type is a bounded TypeVar.

    ``HeroFields`` is generic over ``TagT`` (bound to ``int``); ``Hero``
    specialises it with ``int`` and is declared as a table. The test adds one
    row and reads it back by its ``tag`` value within the same session.
    """
    TagT = TypeVar("TagT", bound=int)

    class HeroFields(SQLModel, Generic[TagT]):
        tag: TagT

    class Hero(HeroFields[int], table=True):
        id: int | None = Field(default=None, primary_key=True)

    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)

    expected_tag = 67
    with Session(engine) as session:
        session.add(Hero(tag=expected_tag))

        # The lookup is issued in the same session, without an explicit commit.
        statement = select(Hero).where(Hero.tag == expected_tag)
        fetched = session.exec(statement).first()

        assert fetched is not None
        assert fetched.tag == expected_tag
def test_generic_type_with_constraints(clear_sqlmodel) -> None:
    """Store and fetch a table model whose field type is a constrained TypeVar.

    ``TagT`` is constrained to ``int`` and ``None``; ``Hero`` specialises the
    generic base with ``int`` and is declared as a table. One row is added and
    then looked up by its ``tag`` value within the same session.
    """
    TagT = TypeVar("TagT", int, None)

    class HeroFields(SQLModel, Generic[TagT]):
        tag: TagT

    class Hero(HeroFields[int], table=True):
        id: int | None = Field(default=None, primary_key=True)

    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)

    wanted_tag = 67
    with Session(engine) as session:
        new_hero = Hero(tag=wanted_tag)
        session.add(new_hero)

        # Query by the generic-typed column and take the first match.
        matches = session.exec(select(Hero).where(Hero.tag == wanted_tag))
        found = matches.first()

        assert found is not None
        assert found.tag == wanted_tag
def test_generic_type_with_multiple_type_constraints_raises_error(
    clear_sqlmodel,
) -> None:
    """Expect ``ValueError`` for a table field typed by a multi-constraint TypeVar.

    ``TagT`` is constrained to both ``int`` and ``str``. The TypeVar, the
    generic base and the table model are all declared inside the
    ``pytest.raises`` block, so the test passes if any of those statements
    raises ``ValueError``.
    """
    with pytest.raises(ValueError):
        TagT = TypeVar("TagT", int, str)

        # Generic, non-table base carrying the constrained field.
        class HeroFields(SQLModel, Generic[TagT]):
            tag: TagT

        # Table model specialising the generic base with ``int``.
        class Hero(HeroFields[int], table=True):
            id: int | None = Field(default=None, primary_key=True)
def test_discriminated_union_with_generics(clear_sqlmodel) -> None:
    """Round-trip a rejected refund through a generic table model.

    ``RefundRequestFields`` is generic over the amount, the rejection message
    and the status discriminant. ``RefundRequest`` is the table model with the
    widest types; ``AcceptedRequest`` and ``RejectedRequest`` are narrower,
    non-table variants pinned to one ``RefundStatus`` each. The test stores a
    ``RejectedRequest`` as a ``RefundRequest`` row, selects the rejected rows
    and validates them back into ``RejectedRequest`` instances.
    """
    AmountRefundedT = TypeVar("AmountRefundedT", bound=int | None)
    RejectionMessageT = TypeVar("RejectionMessageT", bound=str | None)

    class RefundStatus(str, Enum):
        ACCEPTED = "ACCEPTED"
        REJECTED = "REJECTED"

    DiscriminantT = TypeVar("DiscriminantT", bound=RefundStatus)

    class RefundRequestFields(
        SQLModel, Generic[AmountRefundedT, RejectionMessageT, DiscriminantT]
    ):
        item_name: str
        amount_refunded: AmountRefundedT
        rejection_message: RejectionMessageT
        status: DiscriminantT

    # Table model: every generic slot takes its widest type.
    class RefundRequest(
        RefundRequestFields[int | None, str | None, RefundStatus], table=True
    ):
        id: int | None = Field(default=None, primary_key=True)
        status: RefundStatus

    # Accepted variant: amount required, no rejection message.
    class AcceptedRequest(RefundRequestFields[int, None, RefundStatus.ACCEPTED]):
        amount_refunded: int
        rejection_message: None = None
        status: Literal[RefundStatus.ACCEPTED] = RefundStatus.ACCEPTED

    # Rejected variant: message required, no refunded amount.
    class RejectedRequest(RefundRequestFields[None, str, RefundStatus.REJECTED]):
        rejection_message: str
        amount_refunded: None = None
        status: Literal[RefundStatus.REJECTED] = RefundStatus.REJECTED

    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)

    with Session(engine) as session:
        rejection = RejectedRequest(
            item_name="EmptyJuice",
            rejection_message="This item cannot be refunded because it has been emptied",
        )
        # Convert the narrow variant into the table model via its dumped fields.
        row = RefundRequest.model_validate(rejection.model_dump())
        session.add(row)

        rejected_filter = select(RefundRequest).where(
            RefundRequest.status == RefundStatus.REJECTED,
        )
        stored_rows = session.exec(rejected_filter).all()

        rejected_requests: list[RejectedRequest] = []
        for stored in stored_rows:
            # Same status check as the query, applied again on the Python side.
            if stored.status == RefundStatus.REJECTED:
                rejected_requests.append(
                    RejectedRequest.model_validate(stored.model_dump())
                )

        assert_type(rejected_requests, list[RejectedRequest])
        assert len(rejected_requests) == 1