Skip to content

Commit f619cc9

Browse files
authored
Merge pull request #13 from harp-tech/gl-dev
Refactor modules for type checking
2 parents 155188e + de09c85 commit f619cc9

6 files changed

Lines changed: 92 additions & 74 deletions

File tree

harp/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def read(
8484
time = micros * _SECONDS_PER_TICK + seconds
8585
payloadtype = payloadtype & ~0x10
8686
if epoch is not None:
87-
time = epoch + pd.to_timedelta(time, "s")
87+
time = epoch + pd.to_timedelta(time, "s") # type: ignore
8888
index = pd.Series(time)
8989
index.name = "time"
9090

@@ -111,6 +111,6 @@ def read(
111111
msgtype = np.ndarray(
112112
nrows, dtype=np.uint8, buffer=data, offset=0, strides=stride
113113
)
114-
msgtype = pd.Categorical.from_codes(msgtype, categories=_messagetypes)
114+
msgtype = pd.Categorical.from_codes(msgtype, categories=_messagetypes) # type: ignore
115115
result["type"] = msgtype
116116
return result

harp/model.py

Lines changed: 63 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,44 @@
55
from __future__ import annotations
66

77
from enum import Enum
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Dict, List, Optional, Union
99
from typing_extensions import Annotated
1010

11-
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, RootModel, conint, field_serializer
11+
from pydantic import (
12+
BaseModel,
13+
BeforeValidator,
14+
ConfigDict,
15+
Field,
16+
RootModel,
17+
field_serializer,
18+
)
1219

1320

1421
class PayloadType(str, Enum):
15-
U8 = 'uint8'
16-
S8 = 'int8'
17-
U16 = 'uint16'
18-
S16 = 'int16'
19-
U32 = 'uint32'
20-
S32 = 'int32'
21-
U64 = 'uint64'
22-
S64 = 'int64'
23-
Float = 'float32'
22+
U8 = "uint8"
23+
S8 = "int8"
24+
U16 = "uint16"
25+
S16 = "int16"
26+
U32 = "uint32"
27+
S32 = "int32"
28+
U64 = "uint64"
29+
S64 = "int64"
30+
Float = "float32"
2431

2532

2633
class Access(Enum):
27-
Read = 'Read'
28-
Write = 'Write'
29-
Event = 'Event'
34+
Read = "Read"
35+
Write = "Write"
36+
Event = "Event"
3037

3138

3239
class MaskValueItem(BaseModel):
3340
model_config = ConfigDict(
34-
extra='forbid',
41+
extra="forbid",
3542
)
36-
value: int = Field(..., description='Specifies the numerical mask value.')
43+
value: int = Field(..., description="Specifies the numerical mask value.")
3744
description: Optional[str] = Field(
38-
None, description='Specifies a summary description of the mask value function.'
45+
None, description="Specifies a summary description of the mask value function."
3946
)
4047

4148
def __int__(self):
@@ -48,70 +55,70 @@ class MaskValue(RootModel[Union[int, MaskValueItem]]):
4855

4956
class BitMask(BaseModel):
5057
description: Optional[str] = Field(
51-
None, description='Specifies a summary description of the bit mask function.'
58+
None, description="Specifies a summary description of the bit mask function."
5259
)
5360
bits: Dict[str, MaskValue]
5461

5562

5663
class GroupMask(BaseModel):
5764
description: Optional[str] = Field(
58-
None, description='Specifies a summary description of the group mask function.'
65+
None, description="Specifies a summary description of the group mask function."
5966
)
6067
values: Dict[str, MaskValue]
6168

6269

6370
class MaskType(RootModel[str]):
6471
root: str = Field(
6572
...,
66-
description='Specifies the name of the bit mask or group mask used to represent the payload value.',
73+
description="Specifies the name of the bit mask or group mask used to represent the payload value.",
6774
)
6875

6976

7077
class InterfaceType(RootModel[str]):
7178
root: str = Field(
7279
...,
73-
description='Specifies the name of the type used to represent the payload value in the high-level interface.',
80+
description="Specifies the name of the type used to represent the payload value in the high-level interface.",
7481
)
7582

7683

7784
class Converter(Enum):
78-
None_ = 'None'
79-
Payload = 'Payload'
80-
RawPayload = 'RawPayload'
85+
None_ = "None"
86+
Payload = "Payload"
87+
RawPayload = "RawPayload"
8188

8289

8390
class MinValue(RootModel[float]):
8491
root: float = Field(
8592
...,
86-
description='Specifies the minimum allowable value for the payload or payload member.',
93+
description="Specifies the minimum allowable value for the payload or payload member.",
8794
)
8895

8996

9097
class MaxValue(RootModel[float]):
9198
root: float = Field(
9299
...,
93-
description='Specifies the maximum allowable value for the payload or payload member.',
100+
description="Specifies the maximum allowable value for the payload or payload member.",
94101
)
95102

96103

97104
class DefaultValue(RootModel[float]):
98105
root: float = Field(
99106
...,
100-
description='Specifies the default value for the payload or payload member.',
107+
description="Specifies the default value for the payload or payload member.",
101108
)
102109

103110

104111
class PayloadMember(BaseModel):
105112
mask: Optional[int] = Field(
106113
None,
107-
description='Specifies the mask used to read and write this payload member.',
114+
description="Specifies the mask used to read and write this payload member.",
108115
)
109116
offset: Optional[int] = Field(
110117
None,
111-
description='Specifies the payload array offset where this payload member is stored.',
118+
description="Specifies the payload array offset where this payload member is stored.",
112119
)
113120
description: Optional[str] = Field(
114-
None, description='Specifies a summary description of the payload member.'
121+
None, description="Specifies a summary description of the payload member."
115122
)
116123
minValue: Optional[MinValue] = None
117124
maxValue: Optional[MaxValue] = None
@@ -122,68 +129,74 @@ class PayloadMember(BaseModel):
122129

123130

124131
class Visibility(Enum):
125-
public = 'public'
126-
private = 'private'
132+
public = "public"
133+
private = "private"
127134

128135

129136
class Register(BaseModel):
130-
address: conint(le=255) = Field(
131-
..., description='Specifies the unique 8-bit address of the register.'
132-
)
137+
address: Annotated[
138+
int,
139+
Field(
140+
le=255, description="Specifies the unique 8-bit address of the register."
141+
),
142+
]
133143
type: Annotated[PayloadType, BeforeValidator(lambda v: PayloadType[v])]
134-
length: Optional[conint(ge=1)] = Field(
135-
default=1, description='Specifies the length of the register payload.'
136-
)
144+
length: Annotated[
145+
Optional[int],
146+
Field(
147+
ge=1, default=1, description="Specifies the length of the register payload."
148+
),
149+
]
137150
access: Union[Access, List[Access]] = Field(
138-
..., description='Specifies the expected use of the register.'
151+
..., description="Specifies the expected use of the register."
139152
)
140153
description: Optional[str] = Field(
141-
None, description='Specifies a summary description of the register function.'
154+
None, description="Specifies a summary description of the register function."
142155
)
143156
minValue: Optional[MinValue] = None
144157
maxValue: Optional[MaxValue] = None
145158
defaultValue: Optional[DefaultValue] = None
146159
maskType: Optional[MaskType] = None
147160
visibility: Optional[Visibility] = Field(
148161
None,
149-
description='Specifies whether the register function is exposed in the high-level interface.',
162+
description="Specifies whether the register function is exposed in the high-level interface.",
150163
)
151164
volatile: Optional[bool] = Field(
152165
None,
153-
description='Specifies whether register values can be saved in non-volatile memory.',
166+
description="Specifies whether register values can be saved in non-volatile memory.",
154167
)
155168
payloadSpec: Optional[Dict[str, PayloadMember]] = None
156169
interfaceType: Optional[InterfaceType] = None
157170
converter: Optional[Converter] = None
158171

159-
@field_serializer('type')
172+
@field_serializer("type")
160173
def _serialize_type(self, type: PayloadType):
161174
return type.name
162175

163176

164177
class Registers(BaseModel):
165178
registers: Dict[str, Register] = Field(
166179
...,
167-
description='Specifies the collection of registers implementing the device function.',
180+
description="Specifies the collection of registers implementing the device function.",
168181
)
169182
bitMasks: Optional[Dict[str, BitMask]] = Field(
170183
None,
171-
description='Specifies the collection of masks available to be used with the different registers.',
184+
description="Specifies the collection of masks available to be used with the different registers.",
172185
)
173186
groupMasks: Optional[Dict[str, GroupMask]] = Field(
174187
None,
175-
description='Specifies the collection of group masks available to be used with the different registers.',
188+
description="Specifies the collection of group masks available to be used with the different registers.",
176189
)
177190

178191

179192
class Model(Registers):
180-
device: str = Field(..., description='Specifies the name of the device.')
193+
device: str = Field(..., description="Specifies the name of the device.")
181194
whoAmI: int = Field(
182-
..., description='Specifies the unique identifier for this device type.'
195+
..., description="Specifies the unique identifier for this device type."
183196
)
184197
firmwareVersion: str = Field(
185-
..., description='Specifies the semantic version of the device firmware.'
198+
..., description="Specifies the semantic version of the device firmware."
186199
)
187200
hardwareTargets: str = Field(
188-
..., description='Specifies the semantic version of the device hardware.'
201+
..., description="Specifies the semantic version of the device hardware."
189202
)

harp/reader.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from math import log2
33
from os import PathLike
44
from functools import partial
5+
from numpy import dtype
56
from pandas import DataFrame, Series
67
from typing import Any, BinaryIO, Iterable, Callable, Optional, Union
78
from pandas._typing import Axes
89
from harp.model import BitMask, GroupMask, Model, PayloadMember, Register
910
from harp.io import read
11+
from harp.schema import read_schema
1012

1113
_camel_to_snake_regex = re.compile(r"(?<!^)(?=[A-Z])")
1214

1315

1416
class RegisterReader:
1517
register: Register
16-
read: Callable[[str], DataFrame]
18+
read: Callable[[Union[str, bytes, PathLike[Any], BinaryIO]], DataFrame]
1719

1820
def __init__(
1921
self,
@@ -89,14 +91,14 @@ def _create_payloadmember_parser(device: Model, member: PayloadMember):
8991
if offset is None:
9092
offset = 0
9193

92-
shift = None
94+
shift = 0
9395
if member.mask is not None:
9496
shift = _mask_shift(member.mask)
9597

9698
lookup = None
9799
if member.maskType is not None:
98100
key = member.maskType.root
99-
groupMask = device.groupMasks.get(key)
101+
groupMask = None if device.groupMasks is None else device.groupMasks.get(key)
100102
if groupMask is not None:
101103
lookup = _create_groupmask_lookup(groupMask)
102104

@@ -109,7 +111,7 @@ def parser(df: DataFrame):
109111
if member.mask is not None:
110112
series = series & member.mask
111113
if shift > 0:
112-
series = Series(series.values >> shift, series.index)
114+
series = Series(series.values >> shift, series.index) # type: ignore
113115
if is_boolean:
114116
series = series != 0
115117
elif lookup is not None:
@@ -126,7 +128,7 @@ def reader(
126128
data = read(
127129
file,
128130
address=register.address,
129-
dtype=register.type,
131+
dtype=dtype(register.type),
130132
length=register.length,
131133
columns=columns,
132134
)
@@ -141,13 +143,13 @@ def _create_register_parser(device: Model, name: str):
141143

142144
if register.maskType is not None:
143145
key = register.maskType.root
144-
bitMask = device.bitMasks.get(key)
146+
bitMask = None if device.bitMasks is None else device.bitMasks.get(key)
145147
if bitMask is not None:
146148
parser = _create_bitmask_parser(bitMask)
147149
reader = _compose(parser, reader)
148150
return RegisterReader(register, reader)
149151

150-
groupMask = device.groupMasks.get(key)
152+
groupMask = None if device.groupMasks is None else device.groupMasks.get(key)
151153
if groupMask is not None:
152154
parser = _create_groupmask_parser(name, groupMask)
153155
reader = _compose(parser, reader)
@@ -170,7 +172,8 @@ def parser(df: DataFrame):
170172
return RegisterReader(register, reader)
171173

172174

173-
def create_reader(device: Model):
175+
def create_reader(file: Union[str, PathLike], include_common_registers: bool = True):
176+
device = read_schema(file, include_common_registers)
174177
reg_readers = {
175178
name: _create_register_parser(device, name) for name in device.registers.keys()
176179
}

harp/schema.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,26 @@
88

99

1010
def _read_common_registers(file: Union[str, PathLike, TextIO]) -> Registers:
11-
try:
11+
if isinstance(file, TextIO):
12+
return parse_yaml_raw_as(Registers, file.read())
13+
else:
1214
with open(file) as fileIO:
1315
return _read_common_registers(fileIO)
14-
except TypeError:
15-
return parse_yaml_raw_as(Registers, file.read())
1616

1717

1818
def read_schema(
1919
file: Union[str, PathLike, TextIO], include_common_registers: bool = True
2020
) -> Model:
21-
try:
22-
with open(file) as fileIO:
23-
return read_schema(fileIO, include_common_registers)
24-
except TypeError:
21+
if isinstance(file, TextIO):
2522
schema = parse_yaml_raw_as(Model, file.read())
2623
if not "WhoAmI" in schema.registers and include_common_registers:
2724
common = _read_common_registers(_common_yaml_path)
2825
schema.registers = dict(common.registers, **schema.registers)
29-
schema.bitMasks = dict(common.bitMasks, **schema.bitMasks)
30-
schema.groupMasks = dict(common.groupMasks, **schema.groupMasks)
26+
if schema.bitMasks is not None and common.bitMasks is not None:
27+
schema.bitMasks = dict(common.bitMasks, **schema.bitMasks)
28+
if schema.groupMasks is not None and common.groupMasks is not None:
29+
schema.groupMasks = dict(common.groupMasks, **schema.groupMasks)
3130
return schema
31+
else:
32+
with open(file) as fileIO:
33+
return read_schema(fileIO, include_common_registers)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ extend-exclude = '''
6565
(
6666
^/LICENSE
6767
^/README.md
68-
| harp/model.py
6968
| reflex-generator
7069
)
7170
'''

0 commit comments

Comments
 (0)