Skip to content

Commit 30f7bfe

Browse files
committed
refactor: add type hints and pydantic validation schemas
1 parent 1862c67 commit 30f7bfe

7 files changed

Lines changed: 181 additions & 117 deletions

File tree

openpaygo/metrics_request.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,66 @@
11
import copy
2-
2+
from typing import Optional, List, Dict, Any, Union
3+
from .models import MetricsDataFormat, MetricsHistoricalDataStep
34
from .metrics_shared import OpenPAYGOMetricsShared
45

56

67
class MetricsRequestHandler(object):
78
def __init__(
8-
self, serial_number, data_format=None, secret_key=None, auth_method=None
9-
):
9+
self, serial_number: str, data_format: Optional[Union[Dict[str, Any], MetricsDataFormat]] = None, secret_key: Optional[str] = None, auth_method: Optional[str] = None
10+
) -> None:
1011
self.secret_key = secret_key
1112
self.auth_method = auth_method
12-
self.request_dict = {
13+
self.request_dict: Dict[str, Any] = {
1314
"serial_number": serial_number,
1415
}
15-
self.data_format = data_format
16+
if data_format is not None:
17+
if isinstance(data_format, dict):
18+
self.data_format: Optional[Dict[str, Any]] = MetricsDataFormat.model_validate(data_format).model_dump(exclude_none=True)
19+
else:
20+
self.data_format: Optional[Dict[str, Any]] = data_format.model_dump(exclude_none=True)
21+
else:
22+
self.data_format: Optional[Dict[str, Any]] = None
23+
1624
if self.data_format:
1725
if self.data_format.get("id"):
18-
self.request_dict["data_format_id"] = data_format.get("id")
26+
self.request_dict["data_format_id"] = self.data_format.get("id")
1927
else:
20-
self.request_dict["data_format"] = data_format
21-
self.data = {}
22-
self.historical_data = {}
28+
self.request_dict["data_format"] = self.data_format
29+
self.data: Dict[str, Any] = {}
30+
self.historical_data: List[Dict[str, Any]] = []
2331

24-
def set_request_count(self, request_count):
32+
def set_request_count(self, request_count: int) -> None:
2533
self.request_dict["request_count"] = request_count
2634

27-
def set_timestamp(self, timestamp):
35+
def set_timestamp(self, timestamp: int) -> None:
2836
self.request_dict["timestamp"] = timestamp
2937

30-
def set_data(self, data):
38+
def set_data(self, data: Dict[str, Any]) -> None:
3139
self.data = data
3240

33-
def set_historical_data(self, historical_data):
34-
if not self.data_format.get("historical_data_interval"):
35-
for time_step in historical_data:
41+
def set_historical_data(self, historical_data: List[Union[Dict[str, Any], MetricsHistoricalDataStep]]) -> None:
42+
validated_historical_data = []
43+
for time_step in historical_data:
44+
if isinstance(time_step, dict):
45+
step = MetricsHistoricalDataStep.model_validate(time_step).model_dump(exclude_none=True, mode='json')
46+
else:
47+
step = time_step.model_dump(exclude_none=True, mode='json')
48+
validated_historical_data.append(step)
49+
50+
if self.data_format and not self.data_format.get("historical_data_interval"):
51+
for time_step in validated_historical_data:
3652
if not time_step.get("timestamp"):
3753
raise ValueError(
3854
"Historical Data objects must have a time stamp if no "
3955
"historical_data_interval is defined."
4056
)
41-
self.historical_data = historical_data
57+
self.historical_data = validated_historical_data
4258

43-
def get_simple_request_payload(self):
59+
def get_simple_request_payload(self) -> str:
4460
payload = self.get_simple_request_dict()
4561
return OpenPAYGOMetricsShared.convert_to_metrics_json(payload)
4662

47-
def get_simple_request_dict(self):
63+
def get_simple_request_dict(self) -> Dict[str, Any]:
4864
simple_request = self.request_dict
4965
simple_request["data"] = self.data
5066
simple_request["historical_data"] = self.historical_data
@@ -57,11 +73,11 @@ def get_simple_request_dict(self):
5773
)
5874
return simple_request
5975

60-
def get_condensed_request_payload(self):
76+
def get_condensed_request_payload(self) -> str:
6177
payload = self.get_condensed_request_dict()
6278
return OpenPAYGOMetricsShared.convert_to_metrics_json(payload)
6379

64-
def get_condensed_request_dict(self):
80+
def get_condensed_request_dict(self) -> Dict[str, Any]:
6581
if not self.data_format:
6682
raise ValueError("No Data Format provided for condensed request")
6783
data_order = self.data_format.get("data_order")

openpaygo/metrics_response.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
import copy
22
import json
33
from datetime import datetime, timedelta
4+
from typing import Optional, List, Dict, Any, Union
45

6+
from .models import MetricsDataFormat
57
from .metrics_shared import OpenPAYGOMetricsShared
68

79

810
class MetricsResponseHandler(object):
911
def __init__(
1012
self,
11-
received_metrics,
12-
data_format=None,
13-
secret_key=None,
14-
last_request_count=None,
15-
last_request_timestamp=None,
16-
):
13+
received_metrics: str,
14+
data_format: Optional[Union[Dict[str, Any], MetricsDataFormat]] = None,
15+
secret_key: Optional[str] = None,
16+
last_request_count: Optional[int] = None,
17+
last_request_timestamp: Optional[int] = None,
18+
) -> None:
1719
self.received_metrics = received_metrics
1820
self.request_dict = json.loads(received_metrics)
1921
# We convert the base variable names to simple
@@ -28,38 +30,53 @@ def __init__(
2830
self.timestamp = self.request_dict.get("timestamp")
2931
self.response_dict = {}
3032
self.secret_key = secret_key
31-
self.data_format = data_format
3233
self.last_request_count = last_request_count
3334
self.last_request_timestamp = last_request_timestamp
35+
36+
if data_format is not None:
37+
if isinstance(data_format, dict):
38+
self.data_format: Optional[Dict[str, Any]] = MetricsDataFormat.model_validate(data_format).model_dump(exclude_none=True)
39+
else:
40+
self.data_format: Optional[Dict[str, Any]] = data_format.model_dump(exclude_none=True)
41+
else:
42+
self.data_format: Optional[Dict[str, Any]] = None
43+
3444
if not self.data_format and self.request_dict.get("data_format"):
35-
self.data_format = self.request_dict.get("data_format")
45+
df = self.request_dict.get("data_format")
46+
if isinstance(df, dict):
47+
self.data_format = MetricsDataFormat.model_validate(df).model_dump(exclude_none=True)
48+
else:
49+
self.data_format = df
3650

37-
def get_device_serial(self):
51+
def get_device_serial(self) -> str:
3852
return self.request_dict.get("serial_number")
3953

40-
def get_data_format_id(self):
54+
def get_data_format_id(self) -> Optional[int]:
4155
return self.request_dict.get("data_format_id")
4256

43-
def data_format_available(self):
57+
def data_format_available(self) -> bool:
4458
return self.data_format is not None
4559

4660
def set_device_parameters(
4761
self,
48-
secret_key=None,
49-
data_format=None,
50-
last_request_count=None,
51-
last_request_timestamp=None,
52-
):
62+
secret_key: Optional[str] = None,
63+
data_format: Optional[Union[Dict[str, Any], MetricsDataFormat]] = None,
64+
last_request_count: Optional[int] = None,
65+
last_request_timestamp: Optional[int] = None,
66+
) -> None:
5367
if secret_key:
5468
self.secret_key = secret_key
55-
if data_format:
56-
self.data_format = data_format
69+
if data_format is not None:
70+
if isinstance(data_format, dict):
71+
self.data_format = MetricsDataFormat.model_validate(data_format).model_dump(exclude_none=True)
72+
else:
73+
self.data_format = data_format.model_dump(exclude_none=True)
5774
if last_request_count:
5875
self.last_request_count = last_request_count
5976
if last_request_timestamp:
6077
self.last_request_timestamp = last_request_timestamp
6178

62-
def is_auth_valid(self):
79+
def is_auth_valid(self) -> bool:
6380
auth_string = self.request_dict.get("auth", None)
6481
if not auth_string:
6582
return False
@@ -89,7 +106,7 @@ def is_auth_valid(self):
89106
return True
90107
return False
91108

92-
def get_simple_metrics(self):
109+
def get_simple_metrics(self) -> Dict[str, Any]:
93110
# We start the process by making a copy of the dict to work with
94111
simple_dict = copy.deepcopy(self.request_dict)
95112
simple_dict.pop("auth") if "auth" in simple_dict else None # We remove the auth
@@ -103,34 +120,34 @@ def get_simple_metrics(self):
103120
)
104121
return simple_dict
105122

106-
def get_data_timestamp(self):
123+
def get_data_timestamp(self) -> int:
107124
return self.request_dict.get("data_collection_timestamp", self.timestamp)
108125

109-
def get_request_timestamp(self):
126+
def get_request_timestamp(self) -> Optional[int]:
110127
return self.request_timestamp
111128

112-
def get_request_count(self):
129+
def get_request_count(self) -> Optional[int]:
113130
return self.request_dict.get("request_count")
114131

115-
def get_token_count(self):
132+
def get_token_count(self) -> Optional[int]:
116133
data = self._get_simple_data()
117134
return data.get("token_count")
118135

119-
def expects_token_answer(self):
136+
def expects_token_answer(self) -> bool:
120137
return self.get_token_count() is not None
121138

122-
def add_tokens_to_answer(self, token_list):
139+
def add_tokens_to_answer(self, token_list: List[str]) -> None:
123140
self.response_dict["token_list"] = token_list
124141

125-
def expects_time_answer(self):
142+
def expects_time_answer(self) -> bool:
126143
data = self._get_simple_data()
127144
if data.get("active_until_timestamp_requested", False) or data.get(
128145
"active_seconds_left_requested", False
129146
):
130147
return True
131148
return False
132149

133-
def add_time_to_answer(self, target_datetime):
150+
def add_time_to_answer(self, target_datetime: datetime) -> None:
134151
data = self._get_simple_data()
135152
if data.get("active_until_timestamp_requested", False):
136153
target_timestamp = 0
@@ -150,24 +167,24 @@ def add_time_to_answer(self, target_datetime):
150167
else:
151168
raise ValueError("No time requested")
152169

153-
def add_new_base_url_to_answer(self, new_base_url):
170+
def add_new_base_url_to_answer(self, new_base_url: str) -> None:
154171
self.add_settings_to_answer({"base_url": new_base_url})
155172

156-
def add_settings_to_answer(self, settings_dict):
173+
def add_settings_to_answer(self, settings_dict: Dict[str, Any]) -> None:
157174
if not self.response_dict.get("settings"):
158175
self.response_dict["settings"] = {}
159176
self.response_dict["settings"].update(settings_dict)
160177

161-
def add_extra_data_to_answer(self, extra_data_dict):
178+
def add_extra_data_to_answer(self, extra_data_dict: Dict[str, Any]) -> None:
162179
if not self.response_dict.get("extra_data"):
163180
self.response_dict["extra_data"] = {}
164181
self.response_dict["extra_data"].update(extra_data_dict)
165182

166-
def get_answer_payload(self):
183+
def get_answer_payload(self) -> str:
167184
payload = self.get_answer_dict()
168185
return OpenPAYGOMetricsShared.convert_to_metrics_json(payload)
169186

170-
def get_answer_dict(self):
187+
def get_answer_dict(self) -> Dict[str, Any]:
171188
# If there is not data format, we just return the full response
172189
condensed_answer = copy.deepcopy(self.response_dict)
173190
if self.secret_key:

openpaygo/models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Optional, List, Any, Dict
2+
from pydantic import BaseModel, ConfigDict
3+
4+
5+
class MetricsDataFormat(BaseModel):
6+
id: Optional[int] = None
7+
data_order: Optional[List[str]] = None
8+
historical_data_order: Optional[List[str]] = None
9+
historical_data_interval: Optional[int] = None
10+
11+
12+
class MetricsHistoricalDataStep(BaseModel):
13+
timestamp: Optional[int] = None
14+
relative_time: Optional[int] = None
15+
model_config = ConfigDict(extra='allow')
16+
17+
18+
class MetricsRequestData(BaseModel):
19+
serial_number: str
20+
data_format_id: Optional[int] = None
21+
data_format: Optional[MetricsDataFormat] = None
22+
data: Optional[Dict[str, Any]] = None
23+
historical_data: Optional[List[Any]] = None
24+
request_count: Optional[int] = None
25+
timestamp: Optional[int] = None
26+
auth: Optional[str] = None
27+
model_config = ConfigDict(extra='allow')

0 commit comments

Comments
 (0)