Skip to content

Commit c22faca

Browse files
committed
Use Any for refs; make all_rows 'read-only'
1 parent def8b0a commit c22faca

4 files changed

Lines changed: 31 additions & 29 deletions

File tree

machine/corpora/n_parallel_text_corpus.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class _RangeRow:
14-
refs: List[object]
14+
refs: List[Any]
1515
segment: List[str]
1616
is_sentence_start: bool = False
1717

@@ -49,8 +49,8 @@ def add_text_row(self, row: TextRow, index: int):
4949
self.rows[index].segment.extend(row.segment)
5050

5151
def create_row(self) -> NParallelTextRow:
52-
refs: List[List[object]] = [[] for _ in range(self.n)]
53-
reference_refs: List[object] = [r.refs[0] if len(r.refs) > 0 else None for r in self.rows if len(r.refs) > 0]
52+
refs: List[List[Any]] = [[] for _ in range(self.n)]
53+
reference_refs: List[Any] = [r.refs[0] if len(r.refs) > 0 else None for r in self.rows if len(r.refs) > 0]
5454
for i in range(len(self.rows)):
5555
row = self.rows[i]
5656

@@ -76,12 +76,10 @@ def create_row(self) -> NParallelTextRow:
7676

7777

7878
class NParallelTextCorpus(NParallelTextCorpusBase):
79-
def __init__(
80-
self, corpora: Sequence[TextCorpus], row_ref_comparer: Optional[Callable[[object, object], int]] = None
81-
):
79+
def __init__(self, corpora: Sequence[TextCorpus], row_ref_comparer: Optional[Callable[[Any, Any], int]] = None):
8280
self._corpora = corpora
8381
self._row_ref_comparer = row_ref_comparer if row_ref_comparer is not None else default_row_ref_comparer
84-
self.all_rows: Sequence[bool] = tuple(False for _ in range(len(corpora)))
82+
self._all_rows: Sequence[bool] = [False for _ in range(len(corpora))]
8583

8684
def is_tokenized(self, i: int) -> bool:
8785
return self.corpora[i].is_tokenized
@@ -95,9 +93,13 @@ def corpora(self) -> Sequence[TextCorpus]:
9593
return list(self._corpora)
9694

9795
@property
98-
def row_ref_comparer(self) -> Callable[[object, object], int]:
96+
def row_ref_comparer(self) -> Callable[[Any, Any], int]:
9997
return self._row_ref_comparer
10098

99+
@property
100+
def all_rows(self) -> Sequence[bool]:
101+
return self._all_rows
102+
101103
def _get_text_ids_from_corpora(self) -> Set[str]:
102104
text_ids: Set[str] = set()
103105
all_rows_text_ids: Set[str] = set()
@@ -107,7 +109,7 @@ def _get_text_ids_from_corpora(self) -> Set[str]:
107109
else:
108110
text_ids = text_ids.intersection({text.id for text in self.corpora[i].texts})
109111

110-
if self.all_rows[i]:
112+
if self._all_rows[i]:
111113
all_rows_text_ids = all_rows_text_ids.union({text.id for text in self.corpora[i].texts})
112114
text_ids = text_ids.union(all_rows_text_ids)
113115
return text_ids
@@ -129,7 +131,7 @@ def get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Iterable[NParall
129131
def _all_ranges_are_new_ranges(rows: List[TextRow]):
130132
return all([True if row.is_range_start or not row.is_in_range else False for row in rows])
131133

132-
def _min_ref_indexes(self, refs: Sequence[object]) -> Sequence[int]:
134+
def _min_ref_indexes(self, refs: Sequence[Any]) -> Sequence[int]:
133135
min_ref = refs[0]
134136
min_ref_indexes = [0]
135137
for i in range(len(refs)):
@@ -181,7 +183,7 @@ def _get_rows(self, generators: List[TextCorpusEnumerator]) -> Iterable[NParalle
181183
or len([i for i in min_ref_indexes if not completed[i]]) == 1
182184
):
183185
# then there are some non-min refs or only one incomplete generator
184-
if any([not self.all_rows[i] for i in non_min_ref_indexes]) and any(
186+
if any([not self._all_rows[i] for i in non_min_ref_indexes]) and any(
185187
[not completed[i] and current_rows[i].is_in_range for i in min_ref_indexes]
186188
):
187189
# At least one of the non-min rows has not been marked as 'all rows'
@@ -231,7 +233,7 @@ def _get_rows(self, generators: List[TextCorpusEnumerator]) -> Iterable[NParalle
231233
if any(
232234
[
233235
current_rows[i].is_in_range
234-
and all([j == i or not self.all_rows[j] for j in min_ref_indexes])
236+
and all([j == i or not self._all_rows[j] for j in min_ref_indexes])
235237
for i in min_ref_indexes
236238
]
237239
):
@@ -271,7 +273,7 @@ def _get_rows(self, generators: List[TextCorpusEnumerator]) -> Iterable[NParalle
271273
if range_info.is_in_range:
272274
yield range_info.create_row()
273275

274-
def _correct_versification(self, refs: List[object], i: int) -> List[object]:
276+
def _correct_versification(self, refs: List[Any], i: int) -> List[Any]:
275277
if any([not c.is_scripture for c in self.corpora]) or len(refs) == 0:
276278
return refs
277279
return [
@@ -288,7 +290,7 @@ def _create_rows(
288290
default_refs = [[r.ref for r in rows if r is not None][0]]
289291

290292
text_id: Optional[str] = None
291-
refs: List[List[object]] = []
293+
refs: List[List[Any]] = []
292294
flags: List[TextRowFlags] = []
293295
for i in range(self.n):
294296
refs.append([])
@@ -342,7 +344,7 @@ def _create_min_ref_rows(
342344
text_rows = [None for _ in range(self.n)]
343345
rows_have_content = False
344346
for i in min_ref_indexes:
345-
if not self.all_rows[i] or i in already_yielded:
347+
if not self._all_rows[i] or i in already_yielded:
346348
continue
347349
text_row = current_rows[i]
348350
text_rows[i] = text_row

machine/corpora/n_parallel_text_row.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Sequence
1+
from typing import Any, Sequence
22

33
from .text_row import TextRowFlags
44

55

66
class NParallelTextRow:
7-
def __init__(self, text_id: str, n_refs: Sequence[Sequence[object]]):
7+
def __init__(self, text_id: str, n_refs: Sequence[Sequence[Any]]):
88
if len([n_ref for n_ref in n_refs if n_ref is not None and len(n_ref) > 0]) == 0:
99
raise ValueError(f"Refs must be provided but n_refs={n_refs}")
1010
self._text_id = text_id
@@ -18,11 +18,11 @@ def text_id(self) -> str:
1818
return self._text_id
1919

2020
@property
21-
def ref(self) -> object:
21+
def ref(self) -> Any:
2222
return self._n_refs[0][0]
2323

2424
@property
25-
def n_refs(self) -> Sequence[Sequence[object]]:
25+
def n_refs(self) -> Sequence[Sequence[Any]]:
2626
return self._n_refs
2727

2828
def is_sentence_start(self, i: int) -> bool:

machine/corpora/standard_parallel_text_corpus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self._all_source_rows = all_source_rows
2828
self._all_target_rows = all_target_rows
2929
self._n_parallel_text_corpus = NParallelTextCorpus([source_corpus, target_corpus])
30-
self._n_parallel_text_corpus.all_rows = [self._all_source_rows, self.all_target_rows]
30+
self._n_parallel_text_corpus._all_rows = [self._all_source_rows, self.all_target_rows]
3131
self._row_ref_comparer = row_ref_comparer or default_row_ref_comparer
3232

3333
@property

tests/corpora/test_n_parallel_text_corpus.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_get_rows_three_corpora_missing_rows_all_all_rows():
115115
)
116116
)
117117
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
118-
n_parallel_corpus.all_rows = [True, True, True]
118+
n_parallel_corpus._all_rows = [True, True, True]
119119
rows = list(n_parallel_corpus.get_rows())
120120
assert len(rows) == 3
121121
assert all([r[0] == 3 for r in rows[2].n_refs])
@@ -152,7 +152,7 @@ def test_get_rows_three_corpora_missing_rows_some_all_rows():
152152
)
153153
)
154154
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
155-
n_parallel_corpus.all_rows = [True, False, True]
155+
n_parallel_corpus._all_rows = [True, False, True]
156156
rows = list(n_parallel_corpus.get_rows())
157157
assert len(rows) == 2
158158
assert all([r[0] == 3 for r in rows[1].n_refs])
@@ -192,7 +192,7 @@ def test_get_rows_three_corpora_missing_rows_all_all_rows_missing_middle():
192192
)
193193
)
194194
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
195-
n_parallel_corpus.all_rows = [True, True, True]
195+
n_parallel_corpus._all_rows = [True, True, True]
196196
rows = list(n_parallel_corpus.get_rows())
197197
assert len(rows) == 3
198198
assert all([len(r) == 0 or r[0] == 2 for r in rows[1].n_refs])
@@ -228,7 +228,7 @@ def test_get_rows_three_corpora_missing_rows_missing_last_rows():
228228
)
229229
)
230230
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
231-
n_parallel_corpus.all_rows = [True, False, False]
231+
n_parallel_corpus._all_rows = [True, False, False]
232232
rows = list(n_parallel_corpus.get_rows())
233233
assert len(rows) == 3
234234
assert all([r[0] == 2 for r in rows[1].n_refs])
@@ -247,7 +247,7 @@ def test_get_rows_three_corpora_one_corpus():
247247
)
248248
)
249249
n_parallel_corpus = NParallelTextCorpus([corpus1])
250-
n_parallel_corpus.all_rows = [True]
250+
n_parallel_corpus._all_rows = [True]
251251
rows = list(n_parallel_corpus.get_rows())
252252
assert len(rows) == 2
253253
assert all([r[0] == 1 for r in rows[0].n_refs])
@@ -389,7 +389,7 @@ def test_get_rows_three_corpora_overlapping_ranges_all_individual_rows():
389389
)
390390
)
391391
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
392-
n_parallel_corpus.all_rows = [False, False, True]
392+
n_parallel_corpus._all_rows = [False, False, True]
393393
rows = list(n_parallel_corpus.get_rows())
394394
assert len(rows) == 3
395395
assert rows[0].n_refs[0] == [1]
@@ -437,7 +437,7 @@ def test_get_rows_three_corpora_overlapping_ranges_all_one_through_two_rows():
437437
)
438438
)
439439
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
440-
n_parallel_corpus.all_rows = [False, True, False]
440+
n_parallel_corpus._all_rows = [False, True, False]
441441
rows = list(n_parallel_corpus.get_rows())
442442
assert len(rows) == 2
443443
assert rows[0].n_refs[0] == [1, 2]
@@ -485,7 +485,7 @@ def test_get_rows_three_corpora_overlapping_ranges_all_two_through_three_rows():
485485
)
486486
)
487487
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
488-
n_parallel_corpus.all_rows = [True, False, False]
488+
n_parallel_corpus._all_rows = [True, False, False]
489489
rows = list(n_parallel_corpus.get_rows())
490490
assert len(rows) == 2
491491
assert rows[0].n_refs[0] == [1]
@@ -559,7 +559,7 @@ def test_get_rows_three_corpora_same_ref_corpora_of_different_sizes():
559559
)
560560
)
561561
n_parallel_corpus = NParallelTextCorpus([corpus1, corpus2, corpus3])
562-
n_parallel_corpus.all_rows = [True, True, True]
562+
n_parallel_corpus._all_rows = [True, True, True]
563563
rows = list(n_parallel_corpus.get_rows())
564564
assert len(rows) == 4
565565
assert rows[0].n_refs[1] == [1]

0 commit comments

Comments
 (0)