Skip to content

Commit def8b0a

Browse files
committed
Address reviewer comments
1 parent b4e079f commit def8b0a

4 files changed

Lines changed: 28 additions & 22 deletions

File tree

machine/corpora/n_parallel_text_corpus.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from contextlib import ExitStack
2-
from typing import Callable, Iterable, List, Optional, Sequence, Set, cast
2+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Set, cast
33

44
from ..scripture.verse_ref import Versification
55
from .n_parallel_text_corpus_base import NParallelTextCorpusBase
66
from .n_parallel_text_row import NParallelTextRow
77
from .scripture_ref import ScriptureRef
88
from .text_corpus import TextCorpus
9-
from .text_corpus_enumerator import _TextCorpusEnumerator
9+
from .text_corpus_enumerator import TextCorpusEnumerator
1010
from .text_row import TextRow, TextRowFlags
1111

1212

@@ -81,7 +81,7 @@ def __init__(
8181
):
8282
self._corpora = corpora
8383
self._row_ref_comparer = row_ref_comparer if row_ref_comparer is not None else default_row_ref_comparer
84-
self.all_rows = [False for _ in range(len(corpora))]
84+
self.all_rows: Sequence[bool] = tuple(False for _ in range(len(corpora)))
8585

8686
def is_tokenized(self, i: int) -> bool:
8787
return self.corpora[i].is_tokenized
@@ -91,7 +91,7 @@ def n(self) -> int:
9191
return len(self.corpora)
9292

9393
@property
94-
def corpora(self) -> List[TextCorpus]:
94+
def corpora(self) -> Sequence[TextCorpus]:
9595
return list(self._corpora)
9696

9797
@property
@@ -116,11 +116,11 @@ def get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Iterable[NParall
116116
filter_text_ids = self._get_text_ids_from_corpora()
117117
if text_ids is not None:
118118
filter_text_ids = filter_text_ids.intersection(text_ids)
119-
enumerated_corpora: List[_TextCorpusEnumerator] = []
119+
enumerated_corpora: List[TextCorpusEnumerator] = []
120120
for i in range(self.n):
121121
generator = iter(self.corpora[i].get_rows(filter_text_ids))
122122
enumerated_corpora.append(
123-
_TextCorpusEnumerator(generator, self.corpora[0].versification, self.corpora[i].versification)
123+
TextCorpusEnumerator(generator, self.corpora[0].versification, self.corpora[i].versification)
124124
)
125125
for row in self._get_rows(enumerated_corpora):
126126
yield row
@@ -142,7 +142,7 @@ def _min_ref_indexes(self, refs: Sequence[object]) -> Sequence[int]:
142142
min_ref_indexes.append(i)
143143
return min_ref_indexes
144144

145-
def _get_rows(self, generators: List[_TextCorpusEnumerator]) -> Iterable[NParallelTextRow]:
145+
def _get_rows(self, generators: List[TextCorpusEnumerator]) -> Iterable[NParallelTextRow]:
146146
with ExitStack() as stack:
147147
iterators = []
148148
for generator in generators:
@@ -274,7 +274,10 @@ def _get_rows(self, generators: List[_TextCorpusEnumerator]) -> Iterable[NParall
274274
def _correct_versification(self, refs: List[object], i: int) -> List[object]:
275275
if any([not c.is_scripture for c in self.corpora]) or len(refs) == 0:
276276
return refs
277-
return [cast(ScriptureRef, ref).change_versification(self.corpora[i].versification) for ref in refs]
277+
return [
278+
cast(ScriptureRef, ref).change_versification(cast(Versification, self.corpora[i].versification))
279+
for ref in refs
280+
]
278281

279282
def _create_rows(
280283
self, range_info: _NRangeInfo, rows: List[Optional[TextRow]], force_in_range: Optional[Sequence[bool]] = None
@@ -377,7 +380,7 @@ def _create_same_ref_rows(
377380
yield r
378381

379382

380-
def default_row_ref_comparer(x: object, y: object) -> int:
383+
def default_row_ref_comparer(x: Any, y: Any) -> int:
381384
# Do not use the default comparer for ScriptureRef, since we want to ignore segments
382385
if isinstance(x, ScriptureRef) and isinstance(y, ScriptureRef):
383386
return x.compare_to(y, False)
@@ -387,6 +390,6 @@ def default_row_ref_comparer(x: object, y: object) -> int:
387390
return -1
388391
if x == y:
389392
return 0
390-
if x < y: # type: ignore
393+
if x < y:
391394
return -1
392395
return 1

machine/corpora/n_parallel_text_corpus_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Sequence
2+
from typing import Iterable, Sequence
33

44
from .corpus import Corpus
55
from .n_parallel_text_row import NParallelTextRow
@@ -14,7 +14,7 @@ def n(self) -> int: ...
1414

1515
@property
1616
@abstractmethod
17-
def corpora(self) -> List[TextCorpus]: ...
17+
def corpora(self) -> Sequence[TextCorpus]: ...
1818

1919
@abstractmethod
20-
def get_rows(self, text_ids: List[str]) -> Sequence[NParallelTextRow]: ...
20+
def get_rows(self, text_ids: Iterable[str]) -> Sequence[NParallelTextRow]: ...

machine/corpora/text_corpus.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def is_tokenized(self) -> bool: ...
2828

2929
@property
3030
@abstractmethod
31-
def versification(self) -> Versification: ...
31+
def versification(self) -> Optional[Versification]: ...
3232

3333
def get_rows(self, text_ids: Optional[Iterable[str]] = None) -> ContextManagedGenerator[TextRow, None, None]:
3434
return ContextManagedGenerator(self._get_rows(text_ids))
@@ -177,7 +177,7 @@ def is_tokenized(self) -> bool:
177177
return self._is_tokenized
178178

179179
@property
180-
def versification(self) -> Versification:
180+
def versification(self) -> Optional[Versification]:
181181
return self._corpus.versification
182182

183183
def count(self, include_empty: bool = True, text_ids: Optional[Iterable[str]] = None) -> int:
@@ -202,7 +202,7 @@ def is_tokenized(self) -> bool:
202202
return self._corpus.is_tokenized
203203

204204
@property
205-
def versification(self) -> Versification:
205+
def versification(self) -> Optional[Versification]:
206206
return self._corpus.versification
207207

208208
def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[TextRow, None, None]:
@@ -224,7 +224,7 @@ def is_tokenized(self) -> bool:
224224
return self._corpus.is_tokenized
225225

226226
@property
227-
def versification(self) -> Versification:
227+
def versification(self) -> Optional[Versification]:
228228
return self._corpus.versification
229229

230230
def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[TextRow, None, None]:
@@ -246,7 +246,7 @@ def is_tokenized(self) -> bool:
246246
return self._corpus.is_tokenized
247247

248248
@property
249-
def versification(self) -> Versification:
249+
def versification(self) -> Optional[Versification]:
250250
return self._corpus.versification
251251

252252
def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[TextRow, None, None]:
@@ -268,7 +268,7 @@ def is_tokenized(self) -> bool:
268268
return self._corpus.is_tokenized
269269

270270
@property
271-
def versification(self) -> Versification:
271+
def versification(self) -> Optional[Versification]:
272272
return self._corpus.versification
273273

274274
def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[TextRow, None, None]:

machine/corpora/text_corpus_enumerator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
from .text_row import TextRow, TextRowFlags
77

88

9-
class _TextCorpusEnumerator(ContextManager["_TextCorpusEnumerator"], Generator[TextRow, None, None]):
9+
class TextCorpusEnumerator(ContextManager["TextCorpusEnumerator"], Generator[TextRow, None, None]):
1010
def __init__(
11-
self, generator: Generator[TextRow, None, None], ref_versification: Versification, versification: Versification
11+
self,
12+
generator: Generator[TextRow, None, None],
13+
ref_versification: Optional[Versification],
14+
versification: Optional[Versification],
1215
):
1316
self._generator = generator
1417
self._ref_versification = ref_versification
@@ -42,7 +45,7 @@ def close(self) -> None:
4245
super().close()
4346
self._generator.close()
4447

45-
def __enter__(self) -> "_TextCorpusEnumerator":
48+
def __enter__(self) -> "TextCorpusEnumerator":
4649
return self
4750

4851
def __exit__(self, type: Any, value: Any, traceback: Any) -> None:

0 commit comments

Comments
 (0)