Skip to content

Commit c232b07

Browse files
committed
type checking: full code coverage
1 parent 3696e68 commit c232b07

16 files changed

Lines changed: 362 additions & 252 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- parser plugin foundation.
77
- more comprehensive Hyperedge.check_correctness.
88
- check parse correctness.
9+
- type checking: full code coverage.
910

1011
### Changed
1112
- renamed library to hyperbase.

src/hyperbase/hyperedge.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def build_atom(text: str, *parts: str) -> Atom:
140140
return Atom((atom,))
141141

142142

143-
class Hyperedge(tuple):
143+
class Hyperedge(tuple): # type: ignore[type-arg]
144144
"""Non-atomic hyperedge."""
145145
def __new__(cls, edges: Iterable[Hyperedge | None]) -> Hyperedge:
146146
return super(Hyperedge, cls).__new__(cls, tuple(edges))
@@ -206,7 +206,7 @@ def label(self) -> str:
206206
"""Generate human-readable label for edge."""
207207
conn_atom = self.connector_atom()
208208
if len(self) == 2:
209-
edge = self
209+
edge: tuple[Any, ...] = self
210210
elif conn_atom is not None and conn_atom.parts()[-1] == '.':
211211
edge = self[1:]
212212
else:
@@ -231,7 +231,7 @@ def inner_atom(self) -> Atom:
231231
232232
The inner atom of an atom is itself.
233233
"""
234-
return self[1].inner_atom()
234+
return self[1].inner_atom() # type: ignore[no-any-return]
235235

236236
def connector_atom(self) -> Atom | None:
237237
"""The inner atom of the connector.
@@ -243,7 +243,7 @@ def connector_atom(self) -> Atom | None:
243243
244244
The connector atom of an atom is None.
245245
"""
246-
return self[0].inner_atom()
246+
return self[0].inner_atom() # type: ignore[no-any-return]
247247

248248
def atoms(self) -> set[Atom]:
249249
"""Returns the set of atoms contained in the edge.
@@ -408,15 +408,15 @@ def type(self) -> str:
408408
elif ptype[0] == 'M':
409409
if len(self) < 2:
410410
raise RuntimeError('Edge is malformed, type cannot be determined: {}'.format(str(self)))
411-
return self[1].type()
411+
return self[1].type() # type: ignore[no-any-return]
412412
elif ptype[0] == 'T':
413413
outter_type = 'S'
414414
elif ptype[0] == 'B':
415415
outter_type = 'C'
416416
elif ptype[0] == 'J':
417417
if len(self) < 2:
418418
raise RuntimeError('Edge is malformed, type cannot be determined: {}'.format(str(self)))
419-
return self[1].mtype()
419+
return self[1].mtype() # type: ignore[no-any-return]
420420
else:
421421
raise RuntimeError('Edge is malformed, type cannot be determined: {}'.format(str(self)))
422422

@@ -427,7 +427,7 @@ def connector_type(self) -> str | None:
427427
If the edge has no connector (i.e. it's an atom), then None is
428428
returned.
429429
"""
430-
return self[0].type()
430+
return self[0].type() # type: ignore[no-any-return]
431431

432432
def mtype(self) -> str:
433433
"""Returns the main type of this edge as a string of one character.
@@ -458,7 +458,7 @@ def atom_with_type(self, atom_type: str) -> Atom | None:
458458
b/Cp
459459
"""
460460
for item in self:
461-
atom = item.atom_with_type(atom_type)
461+
atom: Atom | None = item.atom_with_type(atom_type)
462462
if atom:
463463
return atom
464464
return None
@@ -488,10 +488,10 @@ def argroles(self) -> str:
488488
"""
489489
et = self.mtype()
490490
if et in {'R', 'C'} and self[0].mtype() in {'B', 'P'}:
491-
return self[0].argroles()
491+
return self[0].argroles() # type: ignore[no-any-return]
492492
if et not in {'B', 'P'}:
493493
return ''
494-
return self[1].argroles()
494+
return self[1].argroles() # type: ignore[no-any-return]
495495

496496
def has_argroles(self) -> bool:
497497
"""Returns True if the edge has argroles, False otherwise."""
@@ -531,8 +531,8 @@ def insert_edge_with_argrole(self, edge: Hyperedge, argrole: str, pos: int) -> H
531531
"""Returns a new edge with the provided edge and its argroles inserted
532532
at the specified position."""
533533
new_edge = self.insert_argrole(argrole, pos)
534-
new_edge = new_edge[:pos + 1] + (edge,) + new_edge[pos + 1:]
535-
return Hyperedge(new_edge)
534+
combined = tuple(new_edge[:pos + 1]) + (edge,) + tuple(new_edge[pos + 1:])
535+
return Hyperedge(combined)
536536

537537
def edges_with_argrole(self, argrole: str) -> list[Hyperedge]:
538538
"""Returns the list of edges with the given argument role."""
@@ -665,26 +665,26 @@ def check_correctness(self) -> dict[Hyperedge, list[tuple[str, str]]]:
665665
return output
666666

667667
def normalized(self) -> Hyperedge | None:
668-
edge = self
668+
edge: Hyperedge = self
669669
conn = edge[0]
670670
ar = conn.argroles()
671671
if ar != '':
672672
if ar[0] == '{':
673673
ar = ar[1:-1]
674-
roles_edges = zip(ar, edge[1:])
675-
roles_edges = sorted(roles_edges, key=lambda role_edge: argrole_order[role_edge[0]])
676-
edge = hedge([conn] + list(role_edge[1] for role_edge in roles_edges))
677-
if not edge:
674+
roles_edges_sorted = sorted(zip(ar, edge[1:]), key=lambda role_edge: argrole_order[role_edge[0]])
675+
new_edge = hedge([conn] + list(role_edge[1] for role_edge in roles_edges_sorted))
676+
if not new_edge:
678677
return None
678+
edge = new_edge
679679
return hedge([subedge.normalized() for subedge in edge])
680680

681-
def __add__(self, other) -> Hyperedge:
682-
if type(other) in {tuple, list}:
683-
return Hyperedge(super(Hyperedge, self).__add__(other))
684-
elif other.atom:
685-
return Hyperedge(super(Hyperedge, self).__add__((other,)))
681+
def __add__(self, other: Hyperedge | tuple[Any, ...] | list[Any]) -> Hyperedge:
682+
if isinstance(other, (list, tuple)) and not isinstance(other, Hyperedge):
683+
return Hyperedge(tuple.__add__(self, tuple(other)))
684+
elif isinstance(other, Hyperedge) and other.atom:
685+
return Hyperedge(tuple.__add__(self, (other,)))
686686
else:
687-
return Hyperedge(super(Hyperedge, self).__add__(other))
687+
return Hyperedge(tuple.__add__(self, tuple(other)))
688688

689689
def __str__(self) -> str:
690690
return self.to_str()
@@ -730,7 +730,7 @@ def is_atom(self) -> bool:
730730

731731
def parts(self) -> list[str]:
732732
"""Splits atom into its parts."""
733-
return self[0].split('/')
733+
return self[0].split('/') # type: ignore[no-any-return]
734734

735735
def root(self) -> str:
736736
"""Extracts the root of an atom
@@ -847,7 +847,7 @@ def contains(self, needle: str, deep: bool = False) -> bool:
847847
848848
Keyword argument:
849849
deep -- search recursively (default: False)"""
850-
return self[0] == needle
850+
return self[0] == needle # type: ignore[no-any-return]
851851

852852
def subedges(self) -> set[Hyperedge]:
853853
"""Returns all the subedges contained in the edge, including atoms
@@ -905,7 +905,7 @@ def role(self) -> list[str]:
905905
906906
['J'].
907907
"""
908-
parts = self[0].split('/')
908+
parts: list[str] = self[0].split('/')
909909
if len(parts) < 2:
910910
return list('J')
911911
else:
@@ -1087,13 +1087,13 @@ def normalized(self) -> Atom:
10871087
return self.replace_argroles(ar)
10881088
return self
10891089

1090-
def __add__(self, other) -> Hyperedge:
1091-
if type(other) in {tuple, list}:
1092-
return Hyperedge((self,) + other)
1093-
elif other.atom:
1090+
def __add__(self, other: Hyperedge | tuple[Any, ...] | list[Any]) -> Hyperedge:
1091+
if isinstance(other, (list, tuple)) and not isinstance(other, Hyperedge):
1092+
return Hyperedge(tuple.__add__((self,), tuple(other)))
1093+
elif isinstance(other, Hyperedge) and other.atom:
10941094
return Hyperedge((self, other))
10951095
else:
1096-
return Hyperedge((self,) + other)
1096+
return Hyperedge(tuple.__add__((self,), tuple(other)))
10971097

10981098

10991099
class UniqueAtom(Atom):
@@ -1112,7 +1112,7 @@ def unique(edge: Hyperedge) -> Hyperedge | None:
11121112
if type(edge) == UniqueAtom:
11131113
return edge
11141114
else:
1115-
return UniqueAtom(edge)
1115+
return UniqueAtom(edge) # type: ignore[arg-type]
11161116
else:
11171117
return hedge([unique(subedge) for subedge in edge])
11181118

src/hyperbase/parsers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def get_parser(name: str, **kwargs: Any) -> Parser:
3333
f"Available parsers: {available}"
3434
)
3535
cls = parsers[name].load()
36-
return cls(**kwargs)
36+
return cls(**kwargs) # type: ignore[no-any-return]
3737

3838

3939
__all__ = ["Parser", "get_parser", "list_parsers"]

src/hyperbase/parsers/correctness.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
1-
from typing import List, Dict
1+
from __future__ import annotations
2+
23
from collections import Counter
4+
from typing import Any
35

46
from hyperbase.hyperedge import Hyperedge
57
from hyperbase.parsers.utils import filter_alphanumeric_strings
68

79

8-
def check_structural_quality(edge: Hyperedge) -> Dict:
9-
errors = {}
10-
11-
def _visit(current_edge):
10+
def check_structural_quality(edge: Hyperedge) -> dict[Hyperedge, list[tuple[str, str, int]]]:
11+
errors: dict[Hyperedge, list[tuple[str, str, int]]] = {}
12+
13+
def _visit(current_edge: Hyperedge) -> None:
1214
if not current_edge or current_edge.atom:
1315
return
1416

15-
current_errors = []
16-
17+
current_errors: list[tuple[str, str, int]] = []
18+
1719
# Argrole checks
1820
try:
1921
ars = current_edge.argroles()
20-
ar_counts = Counter()
22+
ar_counts: Counter[str] = Counter()
2123
for ar in ars:
2224
if ar not in 'mspaoixtjrc':
2325
current_errors.append(('bad-argrole', f"Bad argument role '{ar}'. Should be one of 'mspaoixtjrc'.", 2))
2426
ar_counts[ar] += 1
25-
27+
2628
for role in 'spoiamc':
2729
if ar_counts[role] > 1:
2830
current_errors.append((f'duplicate-argrole-{role}', f"Argument role '{role}' should only be used once.", 2))
@@ -51,34 +53,34 @@ def _visit(current_edge):
5153

5254
def badness_check(
5355
edge: Hyperedge,
54-
tokens: List[str]
55-
) -> Dict[str, List[str]]:
56+
tokens: list[str]
57+
) -> dict[Any, list[tuple[str, str, int]]]:
5658

5759
raw_errors = edge.check_correctness()
58-
errors = {}
60+
errors: dict[Any, list[tuple[str, str, int]]] = {}
5961
for k, v in raw_errors.items():
6062
errors[k] = [(err_type, err_msg, 0) for err_type, err_msg in v]
6163

6264
structural_errors = check_structural_quality(edge)
63-
for k, v in structural_errors.items():
65+
for k, v2 in structural_errors.items():
6466
if k in errors:
65-
errors[k].extend(v)
67+
errors[k].extend(v2)
6668
else:
67-
errors[k] = v
69+
errors[k] = v2
6870

6971
# Only check token matching if we have a valid edge
7072
if edge:
7173
try:
7274
tokens = filter_alphanumeric_strings(tokens)
7375
roots = filter_alphanumeric_strings([atom.label() for atom in edge.all_atoms()])
7476

75-
77+
7678
# Track which tokens and roots have been matched
77-
matched_tokens = set()
78-
matched_roots = set()
79+
matched_tokens: set[int] = set()
80+
matched_roots: set[int] = set()
7981

8082
# Count remaining unmatched instances of each root
81-
def count_unmatched_roots(root_value):
83+
def count_unmatched_roots(root_value: str) -> int:
8284
count = 0
8385
for root_idx, root in enumerate(roots):
8486
if root == root_value and root_idx not in matched_roots:
@@ -119,7 +121,7 @@ def count_unmatched_roots(root_value):
119121
continue # This root is already matched
120122

121123
concatenated = ""
122-
root_sequence = []
124+
root_sequence: list[int] = []
123125

124126
for root_idx in range(root_start_idx, len(roots)):
125127
if root_idx in matched_roots:
@@ -152,7 +154,7 @@ def count_unmatched_roots(root_value):
152154
continue # Already matched
153155

154156
concatenated = ""
155-
token_sequence = []
157+
token_sequence: list[int] = []
156158

157159
for next_token_idx in range(token_idx, len(tokens)):
158160
if next_token_idx in matched_tokens:
@@ -242,7 +244,7 @@ def count_unmatched_roots(root_value):
242244
if token_idx in matched_tokens:
243245
break # Found a match, no need to try other combinations
244246

245-
token_matching_errors = []
247+
token_matching_errors: list[tuple[str, str, int]] = []
246248
# Report unmatched roots
247249
for root_idx, root in enumerate(roots):
248250
if root_idx not in matched_roots:

src/hyperbase/parsers/parser.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,39 @@
1-
from typing import List
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from typing import Any
25

36

47
class Parser:
5-
def sentensize(self, text):
8+
def sentensize(self, text: str) -> list[str]:
69
raise NotImplementedError
710

8-
def parse(self, text):
11+
def parse(self, text: str) -> Iterator[dict[str, Any]]:
912
for sentence in self.sentensize(text):
1013
for parse in self.parse_sentence(sentence):
1114
yield parse
1215

13-
def parse_sentence(self, sentence):
16+
def parse_sentence(self, sentence: str) -> list[dict[str, Any]]:
1417
raise NotImplementedError
1518

16-
def parse_batch(self, sentences: List[str]) -> List[List[dict]]:
19+
def parse_batch(self, sentences: list[str]) -> list[list[dict[str, Any]]]:
1720
"""Parse multiple sentences. Subclasses may override with a
1821
true batched implementation (e.g. a single CT2 call)."""
1922
return [self.parse_sentence(sentence) for sentence in sentences]
2023

2124
def parse_text(
2225
self, text: str, batch_size: int = 8, progress: bool = False
23-
) -> List[dict]:
26+
) -> list[dict[str, Any]]:
2427
"""Sentensize text, then parse all sentences in batches.
2528
2629
Returns a flat list of parse results across all sentences.
2730
"""
2831
sentences = [s for s in self.sentensize(text) if len(s.split()) > 1]
2932
batch_range = range(0, len(sentences), batch_size)
3033
if progress:
31-
from tqdm import tqdm
34+
from tqdm import tqdm # type: ignore[import-untyped]
3235
batch_range = tqdm(batch_range, desc="Parsing batches", leave=False)
33-
results: List[dict] = []
36+
results: list[dict[str, Any]] = []
3437
for i in batch_range:
3538
batch = sentences[i:i + batch_size]
3639
for sentence_results in self.parse_batch(batch):

src/hyperbase/parsers/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from typing import List
2-
3-
4-
def filter_alphanumeric_strings(strings: List[str]) -> List[str]:
1+
def filter_alphanumeric_strings(strings: list[str]) -> list[str]:
52
"""
63
Filter a list of strings to include only those containing alphanumeric characters,
74
and remove all non-alphanumeric characters from each string.
@@ -12,7 +9,7 @@ def filter_alphanumeric_strings(strings: List[str]) -> List[str]:
129
Returns:
1310
Filtered list containing only lowercased alphanumeric characters
1411
"""
15-
filtered = []
12+
filtered: list[str] = []
1613
for s in strings:
1714
# Remove non-alphanumeric characters and lowercase
1815
cleaned = ''.join(c.lower() for c in s if c.isalnum())

0 commit comments

Comments
 (0)