Skip to content

Commit a90a006

Browse files
committed
API: normalize naive datetimes to UTC in validators
1 parent d337209 commit a90a006

5 files changed

Lines changed: 66 additions & 18 deletions

File tree

dp3/api/internal/models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from datetime import datetime
21
from typing import Annotated, Any, Literal, Optional, Union
32

43
from pydantic import BaseModel, Field, TypeAdapter, create_model, model_validator
54

65
from dp3.api.internal.config import MODEL_SPEC
76
from dp3.api.internal.helpers import api_to_dp3_datapoint
8-
from dp3.common.types import T2Datetime
7+
from dp3.common.types import AwareDatetime, T2Datetime
98

109

1110
class DataPoint(BaseModel):
@@ -27,7 +26,7 @@ class DataPoint(BaseModel):
2726
id: Any
2827
attr: str
2928
v: Any
30-
t1: Optional[datetime] = None
29+
t1: Optional[AwareDatetime] = None
3130
t2: Optional[T2Datetime] = Field(None, validate_default=True)
3231
c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0
3332
src: Optional[str] = None

dp3/api/routers/entity.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dp3.api.internal.response_models import ErrorResponse, RequestValidationError, SuccessResponse
2121
from dp3.common.attrspec import AttrType
2222
from dp3.common.task import DataPointTask, task_context
23-
from dp3.common.types import UTC
23+
from dp3.common.types import UTC, AwareDatetime
2424
from dp3.database.database import DatabaseError
2525

2626

@@ -43,7 +43,7 @@ async def parse_eid(etype: str, eid: str):
4343

4444

4545
def get_eid_master_record_handler(
46-
e: EntityId, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
46+
e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
4747
):
4848
"""Handler for getting master record of EID"""
4949
# TODO: This is probably not the most efficient way. Maybe gather only
@@ -67,7 +67,7 @@ def get_eid_master_record_handler(
6767

6868

6969
def get_eid_snapshots_handler(
70-
e: EntityId, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
70+
e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
7171
) -> list[dict[str, Any]]:
7272
"""Handler for getting snapshots of EID"""
7373
snapshots = list(DB.snapshots.get_by_eid(e.type, e.id, t1=date_from, t2=date_to))
@@ -275,7 +275,7 @@ async def count_entity_type_eids(
275275

276276
@router.get("/{etype}/{eid}")
277277
async def get_eid_data(
278-
e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
278+
e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
279279
) -> EntityEidData:
280280
"""Get data of `etype`'s `eid`.
281281
@@ -295,15 +295,15 @@ async def get_eid_data(
295295

296296
@router.get("/{etype}/{eid}/master")
297297
async def get_eid_master_record(
298-
e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
298+
e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
299299
) -> EntityEidMasterRecord:
300300
"""Get master record of `etype`'s `eid`."""
301301
return get_eid_master_record_handler(e, date_from, date_to)
302302

303303

304304
@router.get("/{etype}/{eid}/snapshots")
305305
async def get_eid_snapshots(
306-
e: ParsedEid, date_from: Optional[datetime] = None, date_to: Optional[datetime] = None
306+
e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None
307307
) -> EntityEidSnapshots:
308308
"""Get snapshots of `etype`'s `eid`."""
309309
return get_eid_snapshots_handler(e, date_from, date_to)
@@ -313,8 +313,8 @@ async def get_eid_snapshots(
313313
async def get_eid_attr_value(
314314
e: ParsedEid,
315315
attr: str,
316-
date_from: Optional[datetime] = None,
317-
date_to: Optional[datetime] = None,
316+
date_from: Optional[AwareDatetime] = None,
317+
date_to: Optional[AwareDatetime] = None,
318318
) -> EntityEidAttrValueOrHistory:
319319
"""Get attribute value
320320
@@ -394,7 +394,7 @@ async def get_distinct_attribute_values(etype: str, attr: str) -> dict[JsonVal,
394394

395395

396396
@router.post("/{etype}/{eid}/ttl")
397-
async def extend_eid_ttls(e: ParsedEid, body: dict[str, datetime]) -> SuccessResponse:
397+
async def extend_eid_ttls(e: ParsedEid, body: dict[str, AwareDatetime]) -> SuccessResponse:
398398
"""Extend TTLs of the specified entity"""
399399
# Construct task
400400
with task_context(MODEL_SPEC):

dp3/common/datapoint.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from datetime import datetime
21
from ipaddress import IPv4Address, IPv6Address
32
from typing import Annotated, Any, Optional, Union
43

54
from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer
65

76
from dp3.common.mac_address import MACAddress
8-
from dp3.common.types import T2Datetime
7+
from dp3.common.types import AwareDatetime, T2Datetime
98

109

1110
def to_json_friendly(v):
@@ -60,7 +59,7 @@ class DataPointObservationsBase(DataPointBase):
6059
Contains single raw data value received on API for observations attribute.
6160
"""
6261

63-
t1: datetime
62+
t1: AwareDatetime
6463
t2: T2Datetime = Field(None, validate_default=True)
6564
c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0
6665

@@ -71,7 +70,7 @@ class DataPointTimeseriesBase(DataPointBase):
7170
Contains single raw data value received on API for observations attribute.
7271
"""
7372

74-
t1: datetime
73+
t1: AwareDatetime
7574
t2: T2Datetime = Field(None, validate_default=True)
7675

7776

dp3/common/types.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime, timedelta, timezone
22
from ipaddress import IPv4Address, IPv6Address
33
from json import JSONEncoder
4-
from typing import Annotated, Any, Union
4+
from typing import Annotated, Any, Optional, Union
55

66
from event_count_logger import DummyEventGroup, EventGroup
77
from pydantic import AfterValidator, BeforeValidator
@@ -24,6 +24,18 @@ def parse_timedelta_or_passthrough(v):
2424
ParsedTimedelta = Annotated[timedelta, BeforeValidator(parse_timedelta_or_passthrough)]
2525

2626

27+
def ensure_timezone_aware(v: Optional[datetime]):
28+
"""Ensure datetime is timezone-aware by defaulting to UTC."""
29+
if v is None:
30+
return v
31+
if v.tzinfo is None:
32+
return v.replace(tzinfo=UTC)
33+
return v
34+
35+
36+
AwareDatetime = Annotated[datetime, AfterValidator(ensure_timezone_aware)]
37+
38+
2739
def t2_implicity_t1(v, info: FieldValidationInfo):
2840
"""If t2 is not specified, it is set to t1."""
2941
v = v or info.data.get("t1")
@@ -37,7 +49,11 @@ def t2_after_t1(v, info: FieldValidationInfo):
3749
return v
3850

3951

40-
T2Datetime = Annotated[datetime, BeforeValidator(t2_implicity_t1), AfterValidator(t2_after_t1)]
52+
T2Datetime = Annotated[
53+
AwareDatetime,
54+
BeforeValidator(t2_implicity_t1),
55+
AfterValidator(t2_after_t1),
56+
]
4157

4258
EventGroupType = Union[EventGroup, DummyEventGroup]
4359

tests/test_common/test_types.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import unittest
2+
from datetime import datetime, timedelta, timezone
3+
4+
from pydantic import BaseModel, Field
5+
6+
from dp3.common.types import AwareDatetime, T2Datetime
7+
8+
9+
class _AwareModel(BaseModel):
10+
dt: AwareDatetime
11+
12+
13+
class _T2Model(BaseModel):
14+
t1: AwareDatetime
15+
t2: T2Datetime = Field(None, validate_default=True)
16+
17+
18+
class TestAwareDatetime(unittest.TestCase):
19+
def test_naive_datetime_defaults_to_utc(self):
20+
model = _AwareModel(dt="2024-01-01T10:00:00")
21+
self.assertEqual(model.dt.tzinfo, timezone.utc)
22+
23+
def test_existing_timezone_is_preserved(self):
24+
cest_timezone = timezone(timedelta(hours=2), "CEST")
25+
aware = datetime(2024, 1, 1, 10, 0, tzinfo=cest_timezone)
26+
model = _AwareModel(dt=aware)
27+
self.assertEqual(model.dt.tzinfo, cest_timezone)
28+
29+
def test_t2_datetime_inherits_timezone_when_missing(self):
30+
model = _T2Model(t1="2024-01-01T00:00:00")
31+
self.assertIsNotNone(model.t2)
32+
self.assertEqual(model.t1.tzinfo, timezone.utc)
33+
self.assertEqual(model.t2.tzinfo, timezone.utc)
34+
self.assertEqual(model.t2, model.t1)

0 commit comments

Comments
 (0)