Skip to content

Commit 6b4322d

Browse files
committed
Port NParallelTextCorpus
1 parent 68cb79d commit 6b4322d

6 files changed

Lines changed: 538 additions & 1 deletion

File tree

machine/corpora/corpus.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from itertools import islice
33
from typing import Any, Callable, Generator, Generic, Iterable, Optional, Sequence, Tuple, TypeVar
44

5+
from .n_parallel_text_row import NParallelTextRow
6+
57
from ..utils.context_managed_generator import ContextManagedGenerator
68
from .alignment_row import AlignmentRow
79
from .corpora_utils import batch, get_split_indices
810
from .parallel_text_row import ParallelTextRow
911
from .text_row import TextRow
1012

11-
Row = TypeVar("Row", TextRow, ParallelTextRow, AlignmentRow)
13+
Row = TypeVar("Row", TextRow, ParallelTextRow, AlignmentRow, NParallelTextRow)
1214
Item = TypeVar("Item")
1315

1416

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
from typing import Callable, Generator, Iterable, List, Optional, Sequence, Set, cast
2+
3+
from ..scripture.verse_ref import Versification
4+
5+
from .text_corpus_enumerator import _TextCorpusEnumerator
6+
7+
from .text_row import TextRow, TextRowFlags
8+
9+
from .n_parallel_text_row import NParallelTextRow
10+
11+
from .scripture_ref import ScriptureRef
12+
13+
from .text_corpus import TextCorpus
14+
from .n_parallel_text_corpus_base import NParallelTextCorpusBase
15+
16+
17+
class _RangeRow:
18+
refs: List[object] = []
19+
segment: List[str] = []
20+
is_sentence_start = False
21+
22+
@property
23+
def is_in_range(self):
24+
return len(self.refs) > 0
25+
26+
@property
27+
def is_empty(self):
28+
return len(self.segment) == 0
29+
30+
31+
class _NRangeInfo:
32+
def __init__(self, n: int):
33+
self.n = n
34+
self.rows: List[_RangeRow] = []
35+
for _ in range(n):
36+
self.rows.append(_RangeRow())
37+
self.text_id = ""
38+
self.versifications: Optional[List[Versification]] = None
39+
self.row_ref_comparer = None
40+
41+
@property
42+
def is_in_range(self) -> bool:
43+
return any([row.is_in_range for row in self.rows])
44+
45+
def add_text_row(self, row: TextRow, index: int):
46+
self.text_id = row.text_id
47+
self.rows[index].refs.append(row.ref)
48+
if self.rows[index].is_empty:
49+
self.rows[index].is_sentence_start = row.is_sentence_start
50+
self.rows[index].segment.extend(row.segment)
51+
52+
def create_row(self) -> NParallelTextRow:
53+
refs: List[List[object]] = [[] for _ in range(self.n)]
54+
reference_refs: List[object] = [r.refs[0] if len(r.refs) > 0 else None for r in self.rows if len(r.refs) > 0]
55+
for i in range(len(self.rows)):
56+
row = self.rows[i]
57+
58+
if (
59+
self.versifications is not None
60+
and all([v is not None for v in self.versifications])
61+
and len(row.refs) == 0
62+
):
63+
refs[i] = [cast(ScriptureRef, r).change_versification(self.versifications[i]) for r in reference_refs]
64+
else:
65+
refs[i] = row.refs
66+
n_parallel_text_row = NParallelTextRow(self.text_id, refs)
67+
n_parallel_text_row.n_segments = [r.segment for r in self.rows]
68+
n_parallel_text_row.n_flags = [
69+
TextRowFlags.SENTENCE_START if r.is_sentence_start else TextRowFlags.NONE for r in self.rows
70+
]
71+
self.text_id = ""
72+
for row in self.rows:
73+
row.refs.clear()
74+
row.segment.clear()
75+
row.is_sentence_start = False
76+
return n_parallel_text_row
77+
78+
79+
class NParallelTextCorpus(NParallelTextCorpusBase):
80+
def __init__(
81+
self, corpora: Sequence[TextCorpus], row_ref_comparer: Optional[Callable[[object, object], int]] = None
82+
):
83+
self._corpora = corpora
84+
self._row_ref_comparer = row_ref_comparer if row_ref_comparer is not None else default_row_ref_comparer
85+
self.all_rows = [False for _ in range(len(corpora))]
86+
87+
def is_tokenized(self, i: int) -> bool:
88+
return self.corpora[i].is_tokenized
89+
90+
@property
91+
def n(self) -> int:
92+
return len(self.corpora)
93+
94+
@property
95+
def corpora(self) -> List[TextCorpus]:
96+
return list(self._corpora)
97+
98+
@property
99+
def row_ref_comparer(self) -> Callable[[object, object], int]:
100+
return self._row_ref_comparer
101+
102+
def _get_text_ids_from_corpora(self) -> Set[str]:
103+
text_ids: Set[str] = set()
104+
all_rows_text_ids: Set[str] = set()
105+
for i in range(self.n):
106+
if i == 0:
107+
text_ids = text_ids.union({text.id for text in self.corpora[i].texts})
108+
else:
109+
text_ids = text_ids.intersection({text.id for text in self.corpora[i].texts})
110+
111+
if self.all_rows[i]:
112+
all_rows_text_ids = all_rows_text_ids.union({text.id for text in self.corpora[i].texts})
113+
text_ids = text_ids.union(all_rows_text_ids)
114+
return text_ids
115+
116+
def get_rows(self, text_ids: Optional[Sequence[str]]) -> Iterable[NParallelTextRow]:
117+
filter_text_ids = self._get_text_ids_from_corpora()
118+
if text_ids is not None:
119+
filter_text_ids = filter_text_ids.intersection(text_ids)
120+
enumerated_corpora: List[_TextCorpusEnumerator] = []
121+
for i in range(self.n):
122+
enumerator = iter(self.corpora[i].get_rows(filter_text_ids))
123+
enumerated_corpora.append(
124+
_TextCorpusEnumerator(enumerator, self.corpora[0].versification, self.corpora[i].versification)
125+
)
126+
for row in self._get_rows(enumerated_corpora):
127+
yield row
128+
129+
@staticmethod
130+
def _all_ranges_are_new_ranges(rows: List[TextRow]):
131+
return all([True if row.is_range_start or not row.is_in_range else False for row in rows])
132+
133+
def _min_ref_indexes(self, refs: Sequence[object]) -> Sequence[int]:
134+
min_ref = refs[0]
135+
min_ref_indexes = [0]
136+
for i in range(len(refs)):
137+
if self.row_ref_comparer(refs[i], min_ref) < 0:
138+
min_ref = refs[i]
139+
min_ref_indexes = [i]
140+
elif self.row_ref_comparer(refs[i], min_ref) == 0:
141+
min_ref_indexes.append(i)
142+
return min_ref_indexes
143+
144+
def _get_rows(self, enumerators: List[_TextCorpusEnumerator]) -> Iterable[NParallelTextRow]:
145+
range_info = _NRangeInfo(self.n)
146+
same_ref_rows: List[List[TextRow]] = []
147+
for _ in range(self.n):
148+
same_ref_rows.append([])
149+
150+
completed = [False for _ in range(self.n)]
151+
num_completed = 0
152+
for i in range(self.n):
153+
is_completed = not bool(enumerators[i].move_next())
154+
completed[i] = is_completed
155+
if is_completed:
156+
num_completed += 1
157+
num_remaining_rows = self.n - num_completed
158+
159+
while num_completed < self.n:
160+
current_rows = [cast(TextRow, e.current) for e in enumerators]
161+
refs = []
162+
for i, row in enumerate(current_rows):
163+
if row is not None and not completed[i]:
164+
refs.append(row.ref)
165+
else:
166+
refs.append(None)
167+
min_ref_indexes = self._min_ref_indexes(refs)
168+
non_min_ref_indexes = list(set(range(0, self.n)).difference(min_ref_indexes))
169+
if len(min_ref_indexes) < num_remaining_rows or len([i for i in min_ref_indexes if not completed[i]]):
170+
# then there are some non-min refs or only one incomplete enumerator
171+
if any([not self.all_rows[i] for i in non_min_ref_indexes]) and any(
172+
[not completed[i] and current_rows[i].is_in_range for i in min_ref_indexes]
173+
):
174+
# At least one of the non-min rows has not been marked as 'all rows'
175+
# and at least one of the min rows is not completed and in a range
176+
for i in min_ref_indexes:
177+
range_info.add_text_row(cast(TextRow, enumerators[i].current), i)
178+
for i in non_min_ref_indexes:
179+
same_ref_rows[i].clear()
180+
else:
181+
any_non_min_enumerators_mid_range = any(
182+
[not completed[i] and not current_rows[i].is_range_start and current_rows[i].is_in_range]
183+
)
184+
for row in self._create_min_ref_rows(
185+
range_info,
186+
current_rows,
187+
min_ref_indexes,
188+
non_min_ref_indexes,
189+
same_ref_rows,
190+
[
191+
i in min_ref_indexes
192+
and any_non_min_enumerators_mid_range
193+
and all(
194+
[
195+
not completed[j] and current_rows[j].text_id == current_rows[i].text_id
196+
for j in non_min_ref_indexes
197+
]
198+
)
199+
],
200+
):
201+
yield row
202+
for i in min_ref_indexes:
203+
if completed[i]:
204+
continue
205+
same_ref_rows[i].append(cast(TextRow, enumerators[i].current))
206+
is_completed = not enumerators[i].move_next()
207+
completed[i] = is_completed
208+
if is_completed:
209+
num_completed += 1
210+
num_remaining_rows -= 1
211+
212+
elif len(min_ref_indexes) == num_remaining_rows:
213+
# the refs are all the same
214+
if any(
215+
[
216+
current_rows[i].is_in_range and all([j == i or not self.all_rows[j] for j in min_ref_indexes])
217+
for i in min_ref_indexes
218+
]
219+
):
220+
# At least one row is in range while the other rows are all not marked as 'all rows'
221+
if range_info.is_in_range and NParallelTextCorpus._all_ranges_are_new_ranges(
222+
[row for (i, row) in enumerate(current_rows) if not completed[i]]
223+
):
224+
yield range_info.create_row()
225+
226+
for i in range(len(range_info.rows)):
227+
if completed[i]:
228+
continue
229+
range_info.add_text_row(current_rows[i], i)
230+
same_ref_rows[i].clear()
231+
else:
232+
for row in self._create_same_ref_rows(range_info, completed, current_rows, same_ref_rows):
233+
yield row
234+
235+
for row in self._create_rows(
236+
range_info, [r if completed[i] else None for (i, r) in enumerate(current_rows)]
237+
):
238+
yield row
239+
240+
for i in range(len(range_info.rows)):
241+
if completed[i]:
242+
continue
243+
same_ref_rows[i].append(current_rows[i])
244+
is_completed = not enumerators[i].move_next()
245+
completed[i] = is_completed
246+
if is_completed:
247+
num_completed += 1
248+
num_remaining_rows -= 1
249+
250+
if range_info.is_in_range:
251+
yield range_info.create_row()
252+
253+
def _correct_versification(self, refs: List[object], i: int) -> List[object]:
254+
if any([not c.is_scripture for c in self.corpora]) or len(refs) == 0:
255+
return refs
256+
return [cast(ScriptureRef, ref).change_versification(self.corpora[i].versification) for ref in refs]
257+
258+
def _create_rows(
259+
self, range_info: _NRangeInfo, rows: List[Optional[TextRow]], force_in_range: Optional[Sequence[bool]] = None
260+
) -> Iterable[NParallelTextRow]:
261+
if range_info.is_in_range:
262+
yield range_info.create_row()
263+
264+
default_refs = [r.ref for r in rows if r is not None][0]
265+
text_id: Optional[str] = None
266+
refs: List[List[object]] = []
267+
flags: List[TextRowFlags] = []
268+
for i in range(self.n):
269+
refs.append([])
270+
flags[i] = TextRowFlags.NONE
271+
for i in range(len(rows)):
272+
row = rows[i]
273+
if row is not None:
274+
text_id = text_id or row.text_id
275+
if self.corpora[i].is_scripture:
276+
row = self._correct_versification([row.ref] if row.ref is None else default_refs, i)
277+
else:
278+
refs[i] = default_refs
279+
else:
280+
if self.corpora[i].is_scripture:
281+
refs[i] = self._correct_versification(default_refs, i)
282+
else:
283+
refs[i] = default_refs
284+
flags[i] = (
285+
TextRowFlags.IN_RANGE if force_in_range is not None and force_in_range[i] else TextRowFlags.NONE
286+
)
287+
refs = [r or default_refs for r in refs]
288+
289+
new_row = NParallelTextRow(cast(str, text_id), refs)
290+
new_row.n_segments = [r.segment if r is not None else [] for r in rows]
291+
new_row.n_flags = flags
292+
yield new_row
293+
294+
def _create_min_ref_rows(
295+
self,
296+
range_info: _NRangeInfo,
297+
current_rows: Sequence[TextRow],
298+
min_ref_indexes: Sequence[int],
299+
non_min_ref_indexes: Sequence[int],
300+
same_ref_rows_per_index: Sequence[List[TextRow]],
301+
force_in_range: Optional[Sequence[bool]],
302+
) -> Iterable[NParallelTextRow]:
303+
already_yielded: Set[int] = set()
304+
text_rows: List[Optional[TextRow]] = [None for _ in range(self.n)]
305+
for i in min_ref_indexes:
306+
text_row = current_rows[i]
307+
for j in non_min_ref_indexes:
308+
same_ref_rows = same_ref_rows_per_index[j]
309+
if self._check_same_ref_rows(same_ref_rows, text_row):
310+
already_yielded.add(i)
311+
for same_ref_row in same_ref_rows:
312+
text_rows[i] = text_row
313+
text_rows[j] = same_ref_row
314+
for row in self._create_rows(range_info, text_rows, force_in_range):
315+
yield row
316+
text_rows = [None for _ in range(self.n)]
317+
rows_have_content = False
318+
for i in min_ref_indexes:
319+
if not self.all_rows[i] or i in already_yielded:
320+
continue
321+
text_row = current_rows[i]
322+
text_rows[i] = text_row
323+
rows_have_content = True
324+
325+
if rows_have_content:
326+
for row in self._create_rows(range_info, text_rows, force_in_range):
327+
yield row
328+
329+
def _check_same_ref_rows(self, same_ref_rows: List[TextRow], other_row: TextRow) -> bool:
330+
if len(same_ref_rows) > 0 and self.row_ref_comparer(same_ref_rows[0], other_row.ref) != 0:
331+
same_ref_rows.clear()
332+
return len(same_ref_rows) > 0
333+
334+
def _create_same_ref_rows(
335+
self,
336+
range_info: _NRangeInfo,
337+
completed: Sequence[int],
338+
current_rows: Sequence[TextRow],
339+
same_ref_rows: Sequence[List[TextRow]],
340+
) -> Iterable[NParallelTextRow]:
341+
for i in range(self.n):
342+
if completed[i]:
343+
continue
344+
for j in range(self.n):
345+
if i == j or completed[j]:
346+
continue
347+
348+
if self._check_same_ref_rows(same_ref_rows[i], current_rows[j]):
349+
for tr in same_ref_rows[i]:
350+
text_rows: List[Optional[TextRow]] = [None for _ in range(self.n)]
351+
text_rows[i] = tr
352+
text_rows[j] = current_rows[j]
353+
for r in self._create_rows(range_info, text_rows):
354+
yield r
355+
356+
357+
@staticmethod
358+
def default_row_ref_comparer(x: object, y: object) -> int:
359+
# Do not use the default comparer for ScriptureRef, since we want to ignore segments
360+
if isinstance(x, ScriptureRef) and isinstance(y, ScriptureRef):
361+
return x.compare_to(y, False)
362+
if x is None and y is not None:
363+
return 1
364+
if x is not None and y is None:
365+
return -1
366+
if x == y:
367+
return 0
368+
if x < y: # type: ignore
369+
return -1
370+
return 1

0 commit comments

Comments
 (0)