-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_default.py
More file actions
241 lines (177 loc) · 8.19 KB
/
test_default.py
File metadata and controls
241 lines (177 loc) · 8.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from error_align._cpp_beam_search import error_align_beam_search as cpp_error_align_beam_search
from typeguard import suppress_type_checks
from error_align import error_align
from error_align.backtrace_graph import BacktraceGraph
from error_align.beam_search import _cpp_path_to_py_path
from error_align.beam_search import error_align_beam_search as python_error_align_beam_search
from error_align.edit_distance import compute_error_align_distance_matrix, compute_levenshtein_distance_matrix
from error_align.error_align import prepare_graph_metadata
from error_align.graph_metadata import SubgraphMetadata
from error_align.utils import Alignment, OpType, basic_normalizer, categorize_char, ensure_length_preservation
def test_error_align() -> None:
"""Test error alignment for an example including all substitution types."""
ref = "This is a substitution test deleted."
hyp = "Inserted this is a contribution test."
alignments = error_align(ref, hyp)
expected_ops = [
OpType.INSERT, # Inserted
OpType.MATCH, # This
OpType.MATCH, # is
OpType.MATCH, # a
OpType.SUBSTITUTE, # contribution -> substitution
OpType.MATCH, # test
OpType.DELETE, # deleted
]
for op, alignment in zip(expected_ops, alignments, strict=True):
assert alignment.op_type == op
def test_beam_search_cpp_vs_python() -> None:
"""Test that the C++ and Python beam search implementations produce the same results."""
ref = "This is a substitution test deleted."
hyp = "Inserted this is a contribution test."
graph_metadata = prepare_graph_metadata(ref, hyp)
subgraph_metadata = SubgraphMetadata(
ref_raw=graph_metadata.ref_raw,
hyp_raw=graph_metadata.hyp_raw,
ref_token_matches=graph_metadata.ref_token_matches,
hyp_token_matches=graph_metadata.hyp_token_matches,
ref_norm=graph_metadata.ref_norm,
hyp_norm=graph_metadata.hyp_norm,
)
path_cpp = cpp_error_align_beam_search(src=subgraph_metadata)
path_cpp_converted = _cpp_path_to_py_path(path_cpp)
path_python = python_error_align_beam_search(src=subgraph_metadata)
assert path_cpp_converted.open_cost == path_python.open_cost
assert path_cpp_converted.closed_cost == path_python.closed_cost
assert path_cpp_converted.cost == path_python.cost
assert path_cpp_converted.norm_cost == path_python.norm_cost
assert path_cpp_converted.sort_id == path_python.sort_id
def test_error_align_identical() -> None:
"""Test error alignment for full match."""
ref = "This is a test."
hyp = "This is a test."
alignments = error_align(ref, hyp)
for alignment in alignments:
assert alignment.op_type == OpType.MATCH
def test_partial_substitution_and_insertion() -> None:
"""Test error alignment for partial substitutions and insertions with compound markers."""
ref = "test"
hyp = "testpartial"
alignments = error_align(ref, hyp)
assert len(alignments) == 2
assert alignments[0].op_type == OpType.SUBSTITUTE
assert alignments[0].left_compound is False
assert alignments[0].right_compound is True
assert alignments[1].op_type == OpType.INSERT
assert alignments[1].left_compound is True
assert alignments[1].right_compound is False
def test_categorize_char() -> None:
"""Test character categorization."""
assert categorize_char("<") == 0 # Delimiters
assert categorize_char("b") == 1 # Consonants
assert categorize_char("a") == 2 # Vowels
assert categorize_char("'") == 3 # Unvoiced characters
def test_representations() -> None:
"""Test the string representation of Alignment objects."""
# Test DELETE operation
delete_alignment = error_align("deleted", "")[0]
assert repr(delete_alignment) == 'Alignment(DELETE: "deleted")'
# Test INSERT operation with compound markers
insert_alignment = error_align("", "inserted")[0]
assert repr(insert_alignment) == 'Alignment(INSERT: "inserted")'
# Test SUBSTITUTE operation with compound markers
substitute_alignment = error_align("substitution", "substitutiontesting")[0]
assert substitute_alignment.left_compound is False
assert substitute_alignment.right_compound is True
assert repr(substitute_alignment) == 'Alignment(SUBSTITUTE: "substitution"- -> "substitution")'
# Test MATCH operation without compound markers
match_alignment = error_align("test", "test")[0]
assert repr(match_alignment) == 'Alignment(MATCH: "test" == "test")'
@suppress_type_checks
def test_input_type_checks() -> None:
"""Test input type checks for ErrorAlign class."""
try:
_ = error_align(ref=123, hyp="valid") # type: ignore
except TypeError as e:
assert str(e) == "Reference sequence must be a string."
try:
_ = error_align(ref="valid", hyp=456) # type: ignore
except TypeError as e:
assert str(e) == "Hypothesis sequence must be a string."
def test_backtrace_graph() -> None:
"""Test backtrace graph generation."""
ref = "This is a test."
hyp = "This is a pest."
# Create ErrorAlign instance and generate backtrace graph.
graph_metadata = prepare_graph_metadata(ref, hyp)
subgraph_metadata = SubgraphMetadata(
ref_raw=graph_metadata.ref_raw,
hyp_raw=graph_metadata.hyp_raw,
ref_token_matches=graph_metadata.ref_token_matches,
hyp_token_matches=graph_metadata.hyp_token_matches,
ref_norm=graph_metadata.ref_norm,
hyp_norm=graph_metadata.hyp_norm,
)
_, backtrace_matrix = compute_error_align_distance_matrix(
subgraph_metadata.ref,
subgraph_metadata.hyp,
backtrace=True,
)
graph = BacktraceGraph(backtrace_matrix)
# Check basic properties of the graph.
assert isinstance(graph.get_path(), list)
assert isinstance(graph.get_path(sample=True), list)
for index in graph._iter_topological_order():
assert isinstance(index, tuple)
def test_levenshtein_distance_matrix() -> None:
"""Test Levenshtein distance matrix computation."""
ref = "kitten"
hyp = "sitting"
distance_matrix = compute_levenshtein_distance_matrix(ref, hyp)
assert distance_matrix[-1][-1] == 3
def test_alignment_validation() -> None:
"""Test alignment validation for mismatched inputs."""
try:
Alignment(OpType.MATCH, ref="something", hyp=None)
raise AssertionError("Expected ValueError for MATCH with None hyp.")
except ValueError:
pass
try:
Alignment(OpType.INSERT, ref="something", hyp=None)
raise AssertionError("Expected ValueError for INSERT with None hyp.")
except ValueError:
pass
try:
Alignment(OpType.DELETE, ref=None, hyp="something")
raise AssertionError("Expected ValueError for DELETE with None ref.")
except ValueError:
pass
try:
Alignment(OpType.SUBSTITUTE, ref=None, hyp="something")
raise AssertionError("Expected ValueError for SUBSTITUTE with None ref.")
except ValueError:
pass
def test_alignment_representation() -> None:
"""Test the string representation of Alignment objects."""
alignment = Alignment(OpType.MATCH, ref="test", hyp="test")
assert repr(alignment) == 'Alignment(MATCH: "test" == "test")'
alignment = Alignment(OpType.INSERT, ref=None, hyp="inserted")
assert repr(alignment) == 'Alignment(INSERT: "inserted")'
alignment = Alignment(OpType.DELETE, ref="deleted", hyp=None)
assert repr(alignment) == 'Alignment(DELETE: "deleted")'
alignment = Alignment(OpType.SUBSTITUTE, ref="old", hyp="new")
assert repr(alignment) == 'Alignment(SUBSTITUTE: "new" -> "old")'
def test_normalization_guardrails() -> None:
"""Test normalization guardrails."""
def bad_normalizer(text: str) -> str:
return text + "_"
normalizer = ensure_length_preservation(bad_normalizer)
try:
normalizer("test")
raise AssertionError("Expected ValueError for length mismatch.")
except ValueError:
pass
def test_basic_normalizer_dotted_capital_i() -> None:
"""Regression test: U+0130 (İ) must not expand length when lowercased."""
result = basic_normalizer("İstanbul")
assert result == "istanbul"
assert len(result) == len("İstanbul")