Skip to content

Commit de09c85

Browse files
committed
Refactor modules for typecheck support
1 parent e30500d commit de09c85

5 files changed

Lines changed: 42 additions & 31 deletions

File tree

harp/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def read(
8484
time = micros * _SECONDS_PER_TICK + seconds
8585
payloadtype = payloadtype & ~0x10
8686
if epoch is not None:
87-
time = epoch + pd.to_timedelta(time, "s")
87+
time = epoch + pd.to_timedelta(time, "s") # type: ignore
8888
index = pd.Series(time)
8989
index.name = "time"
9090

@@ -111,6 +111,6 @@ def read(
111111
msgtype = np.ndarray(
112112
nrows, dtype=np.uint8, buffer=data, offset=0, strides=stride
113113
)
114-
msgtype = pd.Categorical.from_codes(msgtype, categories=_messagetypes)
114+
msgtype = pd.Categorical.from_codes(msgtype, categories=_messagetypes) # type: ignore
115115
result["type"] = msgtype
116116
return result

harp/model.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
from enum import Enum
8-
from typing import Any, Dict, List, Optional, Union
8+
from typing import Dict, List, Optional, Union
99
from typing_extensions import Annotated
1010

1111
from pydantic import (
@@ -14,7 +14,6 @@
1414
ConfigDict,
1515
Field,
1616
RootModel,
17-
conint,
1817
field_serializer,
1918
)
2019

@@ -135,13 +134,19 @@ class Visibility(Enum):
135134

136135

137136
class Register(BaseModel):
138-
address: conint(le=255) = Field(
139-
..., description="Specifies the unique 8-bit address of the register."
140-
)
137+
address: Annotated[
138+
int,
139+
Field(
140+
le=255, description="Specifies the unique 8-bit address of the register."
141+
),
142+
]
141143
type: Annotated[PayloadType, BeforeValidator(lambda v: PayloadType[v])]
142-
length: Optional[conint(ge=1)] = Field(
143-
default=1, description="Specifies the length of the register payload."
144-
)
144+
length: Annotated[
145+
Optional[int],
146+
Field(
147+
ge=1, default=1, description="Specifies the length of the register payload."
148+
),
149+
]
145150
access: Union[Access, List[Access]] = Field(
146151
..., description="Specifies the expected use of the register."
147152
)

harp/reader.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from math import log2
33
from os import PathLike
44
from functools import partial
5+
from numpy import dtype
56
from pandas import DataFrame, Series
67
from typing import Any, BinaryIO, Iterable, Callable, Optional, Union
78
from pandas._typing import Axes
89
from harp.model import BitMask, GroupMask, Model, PayloadMember, Register
910
from harp.io import read
11+
from harp.schema import read_schema
1012

1113
_camel_to_snake_regex = re.compile(r"(?<!^)(?=[A-Z])")
1214

1315

1416
class RegisterReader:
1517
register: Register
16-
read: Callable[[str], DataFrame]
18+
read: Callable[[Union[str, bytes, PathLike[Any], BinaryIO]], DataFrame]
1719

1820
def __init__(
1921
self,
@@ -89,14 +91,14 @@ def _create_payloadmember_parser(device: Model, member: PayloadMember):
8991
if offset is None:
9092
offset = 0
9193

92-
shift = None
94+
shift = 0
9395
if member.mask is not None:
9496
shift = _mask_shift(member.mask)
9597

9698
lookup = None
9799
if member.maskType is not None:
98100
key = member.maskType.root
99-
groupMask = device.groupMasks.get(key)
101+
groupMask = None if device.groupMasks is None else device.groupMasks.get(key)
100102
if groupMask is not None:
101103
lookup = _create_groupmask_lookup(groupMask)
102104

@@ -109,7 +111,7 @@ def parser(df: DataFrame):
109111
if member.mask is not None:
110112
series = series & member.mask
111113
if shift > 0:
112-
series = Series(series.values >> shift, series.index)
114+
series = Series(series.values >> shift, series.index) # type: ignore
113115
if is_boolean:
114116
series = series != 0
115117
elif lookup is not None:
@@ -126,7 +128,7 @@ def reader(
126128
data = read(
127129
file,
128130
address=register.address,
129-
dtype=register.type,
131+
dtype=dtype(register.type),
130132
length=register.length,
131133
columns=columns,
132134
)
@@ -141,13 +143,13 @@ def _create_register_parser(device: Model, name: str):
141143

142144
if register.maskType is not None:
143145
key = register.maskType.root
144-
bitMask = device.bitMasks.get(key)
146+
bitMask = None if device.bitMasks is None else device.bitMasks.get(key)
145147
if bitMask is not None:
146148
parser = _create_bitmask_parser(bitMask)
147149
reader = _compose(parser, reader)
148150
return RegisterReader(register, reader)
149151

150-
groupMask = device.groupMasks.get(key)
152+
groupMask = None if device.groupMasks is None else device.groupMasks.get(key)
151153
if groupMask is not None:
152154
parser = _create_groupmask_parser(name, groupMask)
153155
reader = _compose(parser, reader)
@@ -170,7 +172,8 @@ def parser(df: DataFrame):
170172
return RegisterReader(register, reader)
171173

172174

173-
def create_reader(device: Model):
175+
def create_reader(file: Union[str, PathLike], include_common_registers: bool = True):
176+
device = read_schema(file, include_common_registers)
174177
reg_readers = {
175178
name: _create_register_parser(device, name) for name in device.registers.keys()
176179
}

harp/schema.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,26 @@
88

99

1010
def _read_common_registers(file: Union[str, PathLike, TextIO]) -> Registers:
11-
try:
11+
if isinstance(file, TextIO):
12+
return parse_yaml_raw_as(Registers, file.read())
13+
else:
1214
with open(file) as fileIO:
1315
return _read_common_registers(fileIO)
14-
except TypeError:
15-
return parse_yaml_raw_as(Registers, file.read())
1616

1717

1818
def read_schema(
1919
file: Union[str, PathLike, TextIO], include_common_registers: bool = True
2020
) -> Model:
21-
try:
22-
with open(file) as fileIO:
23-
return read_schema(fileIO, include_common_registers)
24-
except TypeError:
21+
if isinstance(file, TextIO):
2522
schema = parse_yaml_raw_as(Model, file.read())
2623
if not "WhoAmI" in schema.registers and include_common_registers:
2724
common = _read_common_registers(_common_yaml_path)
2825
schema.registers = dict(common.registers, **schema.registers)
29-
schema.bitMasks = dict(common.bitMasks, **schema.bitMasks)
30-
schema.groupMasks = dict(common.groupMasks, **schema.groupMasks)
26+
if schema.bitMasks is not None and common.bitMasks is not None:
27+
schema.bitMasks = dict(common.bitMasks, **schema.bitMasks)
28+
if schema.groupMasks is not None and common.groupMasks is not None:
29+
schema.groupMasks = dict(common.groupMasks, **schema.groupMasks)
3130
return schema
31+
else:
32+
with open(file) as fileIO:
33+
return read_schema(fileIO, include_common_registers)

tests/test_io.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import numpy as np
3-
from typing import Iterable, Optional, Type
3+
from os import PathLike
4+
from typing import Iterable, Optional, Type, Union
45
from contextlib import nullcontext
56
from pytest import mark
67
from pathlib import Path
@@ -12,7 +13,7 @@
1213

1314
@dataclass
1415
class DataFileParam:
15-
path: str
16+
path: Union[str, PathLike]
1617
expected_rows: int
1718
expected_cols: Optional[Iterable[str]] = None
1819
expected_address: Optional[int] = None
@@ -36,7 +37,7 @@ def __post_init__(self):
3637
DataFileParam(
3738
path="data/device_0.bin",
3839
expected_rows=1,
39-
expected_dtype="uint8", # actual dtype is uint16
40+
expected_dtype=np.dtype("uint8"), # actual dtype is uint16
4041
expected_error=ValueError,
4142
),
4243
DataFileParam(
@@ -55,7 +56,7 @@ def __post_init__(self):
5556
@mark.parametrize("dataFile", testdata)
5657
def test_read(dataFile: DataFileParam):
5758
context = pytest.raises if dataFile.expected_error is not None else nullcontext
58-
with context(dataFile.expected_error):
59+
with context(dataFile.expected_error): # type: ignore
5960
data = read(
6061
dataFile.path,
6162
address=dataFile.expected_address,

0 commit comments

Comments
 (0)