Skip to content

Commit ecae82a

Browse files
move logic for setting consumer/api header stuff to mprester
1 parent beb78dd commit ecae82a

3 files changed

Lines changed: 20 additions & 33 deletions

File tree

mp_api/client/_server_utils.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""Define utilities needed by the MP web server."""
1+
"""Define flask-dependent utilities for the web server."""
2+
23
from __future__ import annotations
34

45
try:
@@ -10,7 +11,6 @@
1011
_has_request_context = None
1112
request = None
1213

13-
from mp_api.client import MPRester
1414
from mp_api.client.core.utils import validate_api_key
1515

1616
def has_request_context() -> bool:
@@ -34,8 +34,8 @@ def get_request_headers() -> dict[str,Any]:
3434
"""
3535
return request.headers if has_request_context() else {}
3636

37-
def is_localhost() -> bool:
38-
"""Determine if current env is local or production.
37+
def is_dev_env() -> bool:
38+
"""Determine if current env is local/developmental or production.
3939
4040
Returns:
4141
bool: True if the environment is locally hosted.
@@ -83,39 +83,24 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool:
8383
return bool(not c.get("X-Anonymous-Consumer") and c.get("X-Consumer-Id"))
8484

8585

86-
def get_user_api_key(consumer: dict[str, str] | None = None) -> str | None:
86+
def get_user_api_key(
87+
api_key : str | None = None,
88+
consumer: dict[str, str] | None = None
89+
) -> str | None:
8790
"""Get the api key that belongs to the current user.
8891
8992
If running on localhost, api key is obtained from
9093
the environment variable MP_API_KEY.
9194
9295
Args:
96+
api_key (str or None) : User API key
9397
consumer (dict of str to str, or None): Headers associated with the consumer
9498
9599
Returns:
96100
str, the API key, or None if no API key could be identified.
97101
"""
98-
c = consumer or get_consumer()
99-
100-
if is_localhost():
101-
return validate_api_key()
102-
elif is_logged_in_user(c):
102+
if is_dev_env():
103+
return validate_api_key(api_key=api_key)
104+
elif is_logged_in_user(c := consumer or get_consumer()):
103105
return c.get("X-Consumer-Custom-Id")
104-
return None
105-
106-
107-
def get_rester(**kwargs) -> MPRester:
108-
"""Create MPRester with headers set for localhost and production compatibility.
109-
110-
Args:
111-
**kwargs : kwargs to pass to MPRester
112-
113-
Returns:
114-
MPRester
115-
"""
116-
if is_localhost():
117-
dev_api_key = get_user_api_key()
118-
SESSION.headers["x-api-key"] = dev_api_key or ""
119-
return MPRester(api_key=dev_api_key, session=SESSION, **kwargs)
120-
121-
return MPRester(headers=get_consumer(), session=SESSION, **kwargs)
106+
return None

mp_api/client/mprester.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
2222
from requests import Session, get
2323

24+
from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env
2425
from mp_api.client.core import BaseRester
2526
from mp_api.client.core._oxygen_evolution import OxygenEvolution
2627
from mp_api.client.core.exceptions import (
@@ -32,7 +33,6 @@
3233
from mp_api.client.core.utils import (
3334
LazyImport,
3435
load_json,
35-
validate_api_key,
3636
validate_endpoint,
3737
validate_ids,
3838
)
@@ -141,16 +141,18 @@ def __init__(
141141
force_renew: Option to overwrite existing local dataset
142142
**kwargs: access to legacy kwargs that may be in the process of being deprecated
143143
"""
144-
self.api_key = validate_api_key(api_key)
144+
self.api_key = get_user_api_key(api_key=api_key)
145145

146146
self.endpoint = validate_endpoint(endpoint)
147147

148-
self.headers = headers or {}
148+
self.headers = headers or get_consumer()
149149
self.session = session or BaseRester._create_session(
150150
api_key=self.api_key,
151151
include_user_agent=include_user_agent,
152152
headers=self.headers,
153153
)
154+
if is_dev_env():
155+
self.session.headers["x-api-key"] = self.api_key
154156
self._include_user_agent = include_user_agent
155157
self.use_document_model = use_document_model
156158
self.mute_progress_bars = mute_progress_bars

tests/client/core/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,15 @@ def test_api_key_validation(monkeypatch: pytest.MonkeyPatch):
142142
monkeypatch.setattr(pymatgen.core, "SETTINGS", non_api_key_settings)
143143

144144
with pytest.raises(MPRestError, match="32 characters"):
145-
validate_api_key("invalid_key")
145+
validate_api_key(api_key="invalid_key")
146146

147147
with pytest.warns(MPRestWarning, match="No API key found"):
148148
validate_api_key()
149149

150150
junk_api_key = "a" * 32
151151
monkeypatch.setenv("MP_API_KEY", junk_api_key)
152152
assert validate_api_key() == junk_api_key
153-
assert validate_api_key(junk_api_key) == junk_api_key
153+
assert validate_api_key(api_key=junk_api_key) == junk_api_key
154154

155155
other_junk_api_key = "b" * 32
156156
monkeypatch.setattr(

0 commit comments

Comments
 (0)