Skip to content

Commit c58bf79

Browse files
precommit / mypy
1 parent b720db4 commit c58bf79

5 files changed

Lines changed: 19 additions & 24 deletions

File tree

mp_api/client/_server_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_request_headers() -> dict[str, Any]:
4242

4343

4444
def is_dev_env(
45-
localhosts : Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:")
45+
localhosts: Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:")
4646
) -> bool:
4747
"""Determine if current env is local/developmental or production.
4848
@@ -56,9 +56,7 @@ def is_dev_env(
5656
return (
5757
True
5858
if not has_request_context()
59-
else get_request_headers()
60-
.get("Host", "")
61-
.startswith(localhosts)
59+
else get_request_headers().get("Host", "").startswith(localhosts)
6260
)
6361

6462

mp_api/client/core/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ class MAPIClientSettings(BaseSettings):
7272
description="Angle tolerance for structure matching in degrees.",
7373
)
7474

75-
LOG_FILE : Path = Field(
75+
LOG_FILE: Path = Field(
7676
Path("~/.mprester.log.yaml").expanduser(),
77-
description = "Path for storing last accessed database version."
77+
description="Path for storing last accessed database version.",
7878
)
7979

8080
LOCAL_DATASET_CACHE: Path = Field(

mp_api/client/mprester.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def __dir__(self):
300300

301301
def __repr__(self) -> str:
302302
db_version = self.get_database_version()
303-
return f"MPRester({'v' + db_version if db_version else "unknown version"})"
303+
return f"MPRester({'v' + db_version if db_version else 'unknown version'})"
304304

305305
def get_task_ids_associated_with_material_id(
306306
self, material_id: str, calc_types: list[CalcType] | None = None
@@ -1622,21 +1622,19 @@ def get_oxygen_evolution(
16221622

16231623
def _db_version_check(self) -> None:
16241624
"""Check if the database version has drifted."""
1625+
import yaml # type: ignore[import-untyped]
16251626

1626-
import yaml
16271627
db_version = self.get_database_version()
16281628
old_db_version = None
16291629
if MAPI_CLIENT_SETTINGS.LOG_FILE.exists():
16301630
old_db_version = (
1631-
yaml.safe_load(
1632-
MAPI_CLIENT_SETTINGS.LOG_FILE.read_text()
1633-
) or {}
1634-
).get("MAPI_DB_VERSION",None)
1631+
yaml.safe_load(MAPI_CLIENT_SETTINGS.LOG_FILE.read_text()) or {}
1632+
).get("MAPI_DB_VERSION", None)
16351633

16361634
# Handle legacy pymatgen behavior
1637-
if not isinstance(old_db_version,str):
1635+
if not isinstance(old_db_version, str):
16381636
old_db_version = None
1639-
1637+
16401638
if old_db_version != db_version:
16411639
MAPI_CLIENT_SETTINGS.LOG_FILE.write_text(
16421640
yaml.safe_dump({"MAPI_DB_VERSION": db_version})
@@ -1648,4 +1646,4 @@ def _db_version_check(self) -> None:
16481646
f"from v{old_db_version} to v{db_version}.",
16491647
category=MPRestWarning,
16501648
stacklevel=2,
1651-
)
1649+
)

mp_api/client/routes/_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def create_user_settings(
146146
Returns:
147147
Dictionary with consumer_id and write status.
148148
"""
149-
return self._post_resource(
149+
return self._post_resource( # type: ignore[return-value]
150150
body=settings, params={"consumer_id": consumer_id}
151151
).get("data")
152152

@@ -216,6 +216,6 @@ def get_user_settings(
216216
Raises:
217217
MPRestError.
218218
"""
219-
return self._query_resource(
219+
return self._query_resource( # type: ignore[return-value]
220220
suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1
221221
).get("data")

tests/client/test_mprester.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -627,19 +627,18 @@ def test_monty_decode_warning(self):
627627
MPRester(monty_decode=False)
628628

629629
def test_db_warning(self, monkeypatch: pytest.MonkeyPatch):
630-
631630
from pathlib import Path
632631
import yaml
633632
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
634633

635634
with NamedTemporaryFile(suffix=".yaml") as tmp_log:
636-
monkeypatch.setattr(MAPI_CLIENT_SETTINGS,"LOG_FILE",Path(tmp_log.name))
635+
monkeypatch.setattr(MAPI_CLIENT_SETTINGS, "LOG_FILE", Path(tmp_log.name))
637636

638-
with MPRester(notify_db_version = True) as mpr:
637+
with MPRester(notify_db_version=True) as mpr:
639638
db_version = mpr.get_database_version()
640639

641-
parsed_db_ver = yaml.safe_load(
642-
Path(tmp_log.name).read_text()
643-
).get("MAPI_DB_VERSION")
640+
parsed_db_ver = yaml.safe_load(Path(tmp_log.name).read_text()).get(
641+
"MAPI_DB_VERSION"
642+
)
644643
assert parsed_db_ver == db_version
645-
assert isinstance(parsed_db_ver,str)
644+
assert isinstance(parsed_db_ver, str)

0 commit comments

Comments
 (0)