Skip to content

Commit c1cc3f4

Browse files
committed
Passing parallel corpus tests
1 parent 6b4322d commit c1cc3f4

7 files changed

Lines changed: 231 additions & 580 deletions

machine/corpora/corpus.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from itertools import islice
33
from typing import Any, Callable, Generator, Generic, Iterable, Optional, Sequence, Tuple, TypeVar
44

5-
from .n_parallel_text_row import NParallelTextRow
6-
75
from ..utils.context_managed_generator import ContextManagedGenerator
86
from .alignment_row import AlignmentRow
97
from .corpora_utils import batch, get_split_indices
8+
from .n_parallel_text_row import NParallelTextRow
109
from .parallel_text_row import ParallelTextRow
1110
from .text_row import TextRow
1211

machine/corpora/n_parallel_text_corpus.py

Lines changed: 140 additions & 117 deletions
Large diffs are not rendered by default.

machine/corpora/n_parallel_text_corpus_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Sequence
2+
from typing import List, Sequence
33

4-
from .text_corpus import TextCorpus
5-
from .n_parallel_text_row import NParallelTextRow
64
from .corpus import Corpus
5+
from .n_parallel_text_row import NParallelTextRow
6+
from .text_corpus import TextCorpus
77

88

99
class NParallelTextCorpusBase(Corpus[NParallelTextRow], ABC):

machine/corpora/n_parallel_text_row.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import List, Sequence
1+
from typing import Sequence
22

33
from .text_row import TextRowFlags
44

55

66
class NParallelTextRow:
77
def __init__(self, text_id: str, n_refs: Sequence[Sequence[object]]):
8-
if len([n_ref for n_ref in n_refs if n_ref is not None]) == 0:
8+
if len([n_ref for n_ref in n_refs if n_ref is not None and len(n_ref) > 0]) == 0:
99
raise ValueError(f"Refs must be provided but n_refs={n_refs}")
1010
self._text_id = text_id
1111
self._n_refs = n_refs

machine/corpora/standard_parallel_text_corpus.py

Lines changed: 32 additions & 402 deletions
Large diffs are not rendered by default.
Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,93 @@
1-
from typing import Generator, List, Optional, Tuple, cast, overload
2-
3-
from ..utils.context_managed_generator import ContextManagedGenerator
4-
5-
from .scripture_ref import ScriptureRef
1+
from queue import SimpleQueue
2+
from typing import Any, Generator, List, Optional, Tuple, cast, ContextManager
63

74
from ..scripture.verse_ref import Versification
5+
from .scripture_ref import EMPTY_SCRIPTURE_REF, ScriptureRef
86
from .text_row import TextRow, TextRowFlags
97

108

11-
class _TextCorpusEnumerator(ContextManagedGenerator[TextRow, None, None]):
9+
class _TextCorpusEnumerator(ContextManager["_TextCorpusEnumerator"], Generator[TextRow, None, None]):
1210
def __init__(
13-
self, enumerator: Generator[TextRow, None, None], ref_versification: Versification, versification: Versification
11+
self, generator: Generator[TextRow, None, None], ref_versification: Versification, versification: Versification
1412
):
15-
self._enumerator = enumerator
13+
self._generator = generator
1614
self._ref_versification = ref_versification
1715
self._is_scripture = (
1816
ref_versification is not None and versification is not None and ref_versification != versification
1917
)
20-
self._verse_rows: List[TextRow] = []
2118
self._is_enumerating = False
22-
self._enumerator_has_more_data = True
23-
self._current: Optional[TextRow] = None
24-
25-
def __iter__(self):
26-
return self
27-
28-
def __next__(self) -> TextRow:
29-
if not self.move_next() or self._current is None:
30-
raise StopIteration
31-
return self._current
32-
33-
@property
34-
def current(self):
35-
return self._current
19+
self._verse_rows: SimpleQueue[TextRow] = SimpleQueue()
20+
self._row: Optional[TextRow] = None
3621

37-
def move_next(self) -> bool:
22+
def send(self, value: None) -> TextRow:
3823
if self._is_scripture:
3924
if not self._is_enumerating:
40-
self._current = self._enumerator.__next__()
25+
self._row = next(self._generator, None)
4126
self._is_enumerating = True
42-
if len(self._verse_rows) == 0 and self._current is not None and self._enumerator_has_more_data:
27+
if self._verse_rows.empty() and self._row is not None:
4328
self._collect_verses()
44-
if len(self._verse_rows) > 0:
45-
self._current = self._verse_rows.pop(0)
46-
return True
47-
self._current = None
48-
return False
29+
if not self._verse_rows.empty():
30+
return self._verse_rows.get()
31+
raise StopIteration
32+
33+
self._row = next(self._generator, None)
34+
if self._row is not None:
35+
return self._row
36+
raise StopIteration
37+
38+
def throw(self, type: Any, value: Any = None, traceback: Any = None) -> TextRow:
39+
raise StopIteration
4940

50-
self._current = self._enumerator.__next__()
51-
self._enumerator_has_more_data = self._current != None
52-
return self._enumerator_has_more_data
41+
def close(self) -> None:
42+
super().close()
43+
self._generator.close()
5344

54-
# Not porting reset() since it is unused
45+
def __enter__(self) -> "_TextCorpusEnumerator":
46+
return self
47+
48+
def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
49+
self.close()
5550

5651
def _collect_verses(self):
52+
assert self._ref_versification is not None
5753
rows: List[Tuple[ScriptureRef, TextRow]] = []
58-
out_of_order: bool = False
59-
prev_ref = ScriptureRef._empty
60-
range_start_offset: int = -1
61-
while True:
62-
row = cast(TextRow, self._current)
54+
out_of_order = False
55+
prev_ref = EMPTY_SCRIPTURE_REF
56+
range_start_offset = -1
57+
while self._row is not None:
58+
row = cast(TextRow, self._row)
6359
ref = cast(ScriptureRef, row.ref)
64-
if prev_ref is not None and not prev_ref.is_empty and ref.book_num != prev_ref.book_num:
60+
if not prev_ref.is_empty and ref.book_num != prev_ref.book_num:
6561
break
6662

6763
ref = ref.change_versification(self._ref_versification)
64+
# convert one-to-many mapping to a verse range
6865
if ref == prev_ref:
69-
range_start_ref, range_start_row = rows[len(rows) + range_start_offset]
66+
range_start_ref, range_start_row = rows[range_start_offset]
7067
flags = TextRowFlags.IN_RANGE
7168
if range_start_row.is_sentence_start:
7269
flags |= TextRowFlags.SENTENCE_START
7370
if range_start_offset == -1 and (not range_start_row.is_in_range or range_start_row.is_range_start):
7471
flags |= TextRowFlags.RANGE_START
75-
new_text_row = TextRow(range_start_row.text_id, range_start_row.ref)
76-
new_text_row.segment = list(range_start_row.segment) + list(row.segment)
77-
new_text_row.flags = flags
78-
rows[len(rows) + range_start_offset] = range_start_ref, new_text_row
72+
new_text_row = TextRow(
73+
range_start_row.text_id,
74+
range_start_row.ref,
75+
segment=list(range_start_row.segment) + list(row.segment),
76+
flags=flags,
77+
)
78+
rows[range_start_offset] = range_start_ref, new_text_row
79+
row = TextRow(row.text_id, row.ref, flags=TextRowFlags.IN_RANGE)
7980
range_start_offset -= 1
8081
else:
8182
range_start_offset = -1
8283
rows.append((ref, row))
83-
if not out_of_order and ref.compare_to(prev_ref) < 0:
84+
if not out_of_order and ref < prev_ref:
8485
out_of_order = True
8586
prev_ref = ref
86-
self._enumerator_has_more_data = bool(self._enumerator.__next__())
87-
if not self._enumerator_has_more_data:
88-
break
87+
self._row = next(self._generator, None)
8988

9089
if out_of_order:
91-
rows.sort(key=lambda tup: tup[0])
90+
rows.sort(key=lambda t: t[0])
9291

9392
for _, row in rows:
94-
self._verse_rows.append(row)
93+
self._verse_rows.put(row)

tests/corpora/test_parallel_text_corpus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,12 +576,12 @@ def test_get_segments_all_source_rows() -> None:
576576
rows = list(parallel_corpus)
577577
assert len(rows) == 7
578578
assert rows[1].source_refs == [2]
579-
assert rows[1].target_refs == []
579+
assert rows[1].target_refs == [2]
580580
assert rows[1].source_segment == "source segment 2 .".split()
581581
assert rows[1].target_segment == []
582582

583583
assert rows[4].source_refs == [5]
584-
assert rows[4].target_refs == []
584+
assert rows[4].target_refs == [5]
585585
assert rows[4].source_segment == "source segment 5 .".split()
586586
assert rows[4].target_segment == []
587587

0 commit comments

Comments
 (0)