Skip to content

Commit c5465a2

Browse files
committed
Fix typing issues in tests
1 parent 89dcd4c commit c5465a2

26 files changed

Lines changed: 273 additions & 229 deletions

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ jobs:
6868
run: python -m pip freeze
6969

7070
- name: Run flake8, pylint, mypy
71-
if: matrix.python-version == '3.11'
71+
if: matrix.python-version == '3.14'
7272
run: |
7373
flake8 cmdstanpy test
7474
pylint -v cmdstanpy test
75-
mypy cmdstanpy
75+
mypy cmdstanpy test
7676
7777
- name: CmdStan installation cacheing
7878
id: cache-cmdstan

cmdstanpy/install_cmdstan.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@
2929
import urllib.error
3030
import urllib.request
3131
from collections import OrderedDict
32+
from functools import cached_property
3233
from pathlib import Path
3334
from time import sleep
34-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
35+
from typing import Any, Callable, Optional, Union
3536

3637
from tqdm.auto import tqdm
3738

@@ -47,18 +48,6 @@
4748

4849
from . import progress as progbar
4950

50-
if sys.version_info >= (3, 8) or TYPE_CHECKING:
51-
# mypy only knows about the new built-in cached_property
52-
from functools import cached_property
53-
else:
54-
# on older Python versions, this is the recommended
55-
# way to get the same effect
56-
from functools import lru_cache
57-
58-
def cached_property(fun):
59-
return property(lru_cache(maxsize=None)(fun))
60-
61-
6251
try:
6352
# on MacOS and Linux, importing this
6453
# improves the UX of the input() function

cmdstanpy/model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from concurrent.futures import ThreadPoolExecutor
1414
from io import StringIO
1515
from multiprocessing import cpu_count
16-
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
16+
from typing import Any, Callable, Mapping, Optional, Sequence, TypeVar, Union
1717

1818
import numpy as np
1919
import pandas as pd
@@ -461,8 +461,7 @@ def sample(
461461
Mapping[str, Any],
462462
float,
463463
str,
464-
list[str],
465-
list[Mapping[str, Any]],
464+
Sequence[Union[str, Mapping[str, Any]]],
466465
None,
467466
] = None,
468467
iter_warmup: Optional[int] = None,
@@ -493,7 +492,7 @@ def sample(
493492
str,
494493
np.ndarray,
495494
Mapping[str, Any],
496-
list[Union[str, np.ndarray, Mapping[str, Any]]],
495+
Sequence[Union[str, np.ndarray, Mapping[str, Any]]],
497496
None,
498497
] = None,
499498
) -> CmdStanMCMC:
@@ -1360,7 +1359,13 @@ def pathfinder(
13601359
calculate_lp: bool = True,
13611360
# arguments standard to all methods
13621361
seed: Optional[int] = None,
1363-
inits: Union[dict[str, float], float, str, os.PathLike, None] = None,
1362+
inits: Union[
1363+
Mapping[str, Any],
1364+
float,
1365+
str,
1366+
Sequence[Union[str, Mapping[str, Any]]],
1367+
None,
1368+
] = None,
13641369
output_dir: OptionalPath = None,
13651370
sig_figs: Optional[int] = None,
13661371
save_profile: bool = False,

cmdstanpy/stanfit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def from_csv(
248248
mode: CmdStanMLE = from_csv(
249249
config_dict['mode'], # type: ignore
250250
method='optimize',
251-
) # type: ignore
251+
)
252252
return CmdStanLaplace(runset, mode=mode)
253253
elif config_dict['method'] == 'pathfinder':
254254
pathfinder_args = PathfinderArgs(

cmdstanpy/stanfit/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
348348
save_warmup=self._save_warmup,
349349
thin=self._thin,
350350
)
351-
self._chain_time.append(dzero['time']) # type: ignore
351+
self._chain_time.append(dzero['time'])
352352
if not self._is_fixed_param:
353353
self._divergences[i] = dzero['ct_divergences']
354354
self._max_treedepths[i] = dzero['ct_max_treedepth']
@@ -360,7 +360,7 @@ def _validate_csv_files(self) -> dict[str, Any]:
360360
save_warmup=self._save_warmup,
361361
thin=self._thin,
362362
)
363-
self._chain_time.append(drest['time']) # type: ignore
363+
self._chain_time.append(drest['time'])
364364
for key in dzero:
365365
# check args that matter for parsing, plus name, version
366366
if (

cmdstanpy/stanfit/pathfinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def is_resampled(self) -> bool:
208208
Returns True if the draws were resampled from several Pathfinder
209209
approximations, False otherwise.
210210
"""
211-
return ( # type: ignore
211+
return (
212212
self._metadata.cmdstan_config.get("num_paths", 4) > 1
213213
and self._metadata.cmdstan_config.get('psis_resample', 1)
214214
in (1, 'true')

cmdstanpy/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def show_versions(output: bool = True) -> str:
8080
deps_info.append((module, None))
8181
else:
8282
try:
83-
ver = mod.__version__ # type: ignore
83+
ver = mod.__version__
8484
deps_info.append((module, ver))
8585
# pylint: disable=broad-except
8686
except Exception:

cmdstanpy/utils/filesystem.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99
import shutil
1010
import tempfile
11-
from typing import Any, Iterator, Mapping, Optional, Union
11+
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
1212

1313
import numpy as np
1414

@@ -131,10 +131,12 @@ def _temp_single_json(
131131

132132

133133
def _temp_multiinput(
134-
input: Union[str, os.PathLike, Mapping[str, Any], list[Any], None],
134+
input: Union[str, os.PathLike, Mapping[str, Any], Sequence[Any], None],
135135
base: int = 1,
136136
) -> Iterator[Optional[str]]:
137-
if isinstance(input, list):
137+
if isinstance(input, Sequence) and not isinstance(
138+
input, (str, os.PathLike)
139+
):
138140
# most complicated case: list of inits
139141
# for multiple chains, we need to create multiple files
140142
# which look like somename_{i}.json and then pass somename.json
@@ -170,7 +172,7 @@ def _temp_multiinput(
170172
@contextlib.contextmanager
171173
def temp_metrics(
172174
metrics: Union[
173-
str, os.PathLike, Mapping[str, Any], np.ndarray, list[Any], None
175+
str, os.PathLike, Mapping[str, Any], np.ndarray, Sequence[Any], None
174176
],
175177
*,
176178
id: int = 1,
@@ -200,7 +202,7 @@ def temp_metrics(
200202
@contextlib.contextmanager
201203
def temp_inits(
202204
inits: Union[
203-
str, os.PathLike, Mapping[str, Any], float, int, list[Any], None
205+
str, os.PathLike, Mapping[str, Any], float, int, Sequence[Any], None
204206
],
205207
*,
206208
allow_multiple: bool = True,
@@ -212,7 +214,9 @@ def temp_inits(
212214
if allow_multiple:
213215
yield from _temp_multiinput(inits, base=id)
214216
else:
215-
if isinstance(inits, list):
217+
if isinstance(inits, Sequence) and not isinstance(
218+
inits, (str, os.PathLike)
219+
):
216220
raise ValueError('Expected single initialization, got list')
217221
yield from _temp_single_json(inits)
218222

cmdstanpy/utils/stancsv.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import re
1010
import warnings
11-
from typing import Any, Iterator, Mapping, Optional, Union
11+
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
1212

1313
import numpy as np
1414
import numpy.typing as npt
@@ -638,11 +638,13 @@ def try_deduce_metric_type(
638638
str,
639639
np.ndarray,
640640
Mapping[str, Any],
641-
list[Union[str, np.ndarray, Mapping[str, Any]]],
641+
Sequence[Union[str, np.ndarray, Mapping[str, Any]]],
642642
],
643643
) -> Optional[str]:
644644
"""Given a user-supplied metric, try to infer the correct metric type."""
645-
if isinstance(inv_metric, list):
645+
if isinstance(inv_metric, Sequence) and not isinstance(
646+
inv_metric, (str, np.ndarray, Mapping)
647+
):
646648
if inv_metric:
647649
inv_metric = inv_metric[0]
648650

test/__init__.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import platform
66
import re
77
from importlib import reload
8-
from typing import Tuple, Type
8+
from types import ModuleType
9+
from typing import Generator, Optional, Tuple, Type
910
from unittest import mock
1011

1112
import pytest
@@ -20,15 +21,17 @@
2021

2122
# pylint: disable=invalid-name
2223
@contextlib.contextmanager
23-
def raises_nested(expected_exception: Type[Exception], match: str) -> None:
24+
def raises_nested(
25+
expected_exception: Type[Exception], match: str
26+
) -> Generator[None, None, None]:
2427
"""A version of assertRaisesRegex that checks the full traceback.
2528
2629
Useful for when an exception is raised from another and you wish to
2730
inspect the inner exception.
2831
"""
2932
with pytest.raises(expected_exception) as ctx:
3033
yield
31-
exception: Exception = ctx.value
34+
exception: Optional[BaseException] = ctx.value
3235
lines = []
3336
while exception:
3437
lines.append(str(exception))
@@ -38,7 +41,9 @@ def raises_nested(expected_exception: Type[Exception], match: str) -> None:
3841

3942

4043
@contextlib.contextmanager
41-
def without_import(library, module):
44+
def without_import(
45+
library: str, module: ModuleType
46+
) -> Generator[None, None, None]:
4247
with mock.patch.dict('sys.modules', {library: None}):
4348
reload(module)
4449
yield
@@ -58,9 +63,13 @@ def check_present(
5863
if isinstance(level, str):
5964
level = getattr(logging, level)
6065
found = any(
61-
logger == logger_ and level == level_ and message.match(message_)
62-
if isinstance(message, re.Pattern)
63-
else message == message_
66+
(
67+
logger == logger_
68+
and level == level_
69+
and message.match(message_)
70+
if isinstance(message, re.Pattern)
71+
else message == message_
72+
)
6473
for logger_, level_, message_ in caplog.record_tuples
6574
)
6675
if not found:

0 commit comments

Comments
 (0)