Skip to content

Commit cb11eeb

Browse files
move unset fields to separate attr
1 parent 78ec2d5 commit cb11eeb

2 files changed

Lines changed: 53 additions & 31 deletions

File tree

mp_api/client/core/client.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,12 +1205,14 @@ def _submit_request_and_process(
12051205
# other sub-urls may use different document models
12061206
# the client does not handle this in a particularly smart way currently
12071207
if self.document_model and use_document_model:
1208-
requested_fields = (
1209-
params.get("_fields", "").split(",")
1210-
if params.get("_fields")
1211-
else None
1208+
data["data"] = self._convert_to_model(
1209+
data["data"],
1210+
requested_fields=(
1211+
params["_fields"].split(",")
1212+
if isinstance(params.get("_fields"), str)
1213+
else None
1214+
),
12121215
)
1213-
data["data"] = self._convert_to_model(data["data"], requested_fields)
12141216

12151217
meta_total_doc_num = data.get("meta", {}).get("total_doc", 1)
12161218

@@ -1260,15 +1262,14 @@ def _convert_to_model(
12601262
# Return empty list if no data in iterator
12611263
return []
12621264
data_model, set_fields, _ = self._generate_returned_model(
1263-
first_doc, requested_fields
1265+
first_doc, requested_fields=requested_fields
12641266
)
12651267

12661268
return [
12671269
data_model(
12681270
**{
1269-
field: value
1270-
for field, value in dict(raw_doc).items()
1271-
if field in set_fields
1271+
field: raw_doc[field]
1272+
for field in set_fields.intersection(raw_doc)
12721273
}
12731274
)
12741275
for raw_doc in (data if is_list else chain([first_doc], data))
@@ -1280,24 +1281,23 @@ def _generate_returned_model(
12801281
self,
12811282
doc: dict[str, Any],
12821283
requested_fields: list[str] | None = None,
1283-
) -> tuple[type[BaseModel], list[str], list[str]]:
1284+
) -> tuple[type[BaseModel], set[str], set[str]]:
12841285
"""Dynamically generates an MPDataDoc Pydantic model from API response content.
12851286
12861287
Args:
12871288
doc (dict): A single document returned from the API
1288-
requested_fields (list[str] or None): Optional list of fields to be returned
1289+
requested_fields (list of str, or None): Optional list of fields to be returned
12891290
12901291
Returns:
1291-
(tuple): Tuple containing (data_model, set_fields, fields_not_requested)
1292+
BaseModel: the pydantic model representing the data
1293+
set of str: fields set in the document model
1294+
set of str: set_fields, fields_not_requested)
12921295
"""
12931296
model_fields = self.document_model.model_fields
1294-
set_fields = [k for k in doc if k in model_fields]
1295-
unset_fields = [field for field in model_fields if field not in set_fields]
1296-
fields_not_requested = (
1297-
[field for field in unset_fields if field not in requested_fields]
1298-
if requested_fields
1299-
else unset_fields
1300-
)
1297+
set_fields = set(doc).intersection(model_fields)
1298+
unset_fields = set(model_fields).difference(set_fields)
1299+
user_requested_fields: list[str] = requested_fields or []
1300+
fields_not_requested = unset_fields.difference(user_requested_fields)
13011301

13021302
# Update with locals() from external module if needed
13031303
if any(
@@ -1324,7 +1324,11 @@ def _generate_returned_model(
13241324
data_model = create_model( # type: ignore
13251325
"MPDataDoc",
13261326
**include_fields,
1327-
fields_not_requested=(list[str], fields_not_requested),
1327+
fields_not_requested=(list[str], list(fields_not_requested)),
1328+
unavailable_fields=(
1329+
list[str],
1330+
list(unset_fields.intersection(user_requested_fields)),
1331+
),
13281332
__base__=_DictLikeAccess,
13291333
__doc__=".".join(
13301334
[
@@ -1354,13 +1358,19 @@ def new_str(self) -> str:
13541358
if n in set_fields
13551359
)
13561360

1357-
s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{fields_not_requested}" # noqa: E501
1358-
return s
1361+
return (
1362+
f"\033[4m\033[1m{self.__class__.__name__}"
1363+
f"<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m"
1364+
f"\n{extra}\n\n"
1365+
f"\033[1mFields not requested:\033[0;0m\n{fields_not_requested}"
1366+
)
13591367

13601368
def new_getattr(self, attr) -> str:
1369+
if attr in self.unavailable_fields:
1370+
raise AttributeError(f"`{attr}` is unavailable in the returned data.")
13611371
if attr in self.fields_not_requested:
13621372
raise AttributeError(
1363-
f"'{attr}' data is available but has not been requested in 'fields'."
1373+
f"`{attr}` data is available but has not been requested in `fields`."
13641374
" A full list of unrequested fields can be found in `fields_not_requested`."
13651375
)
13661376
else:

tests/client/test_core_client.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,27 @@ def test_available_fields(rester, mpr):
4747
assert rester.available_fields == []
4848

4949

50-
def test_fields_not_requested_excludes_requested_fields(rester: BaseRester):
51-
from emmet.core.tasks import TaskDoc
50+
def test_fields_not_requested_excludes_requested_fields(mpr):
51+
task_rester = mpr.materials.tasks
52+
doc = {"task_id": "fakeid-1234", "state": "successful"}
53+
requested_fields = list(task_rester.document_model.model_fields)
5254

53-
rester.document_model = TaskDoc
54-
doc = {"task_id": "fakeid-1234"}
55-
requested_fields = list(TaskDoc.model_fields.keys())
56-
57-
_, _, fields_not_requested = rester._generate_returned_model(
55+
doc_model, _, fields_not_requested = task_rester._generate_returned_model(
5856
doc, requested_fields=requested_fields
5957
)
58+
deser_doc = doc_model(**doc)
59+
60+
assert "dir_name" not in fields_not_requested
61+
assert "tags" in deser_doc.unavailable_fields
62+
with pytest.raises(AttributeError, match="is unavailable in the returned data"):
63+
deser_doc.tags
6064

61-
assert fields_not_requested == []
65+
requested_fields.remove("tags")
66+
doc_model, _, _ = task_rester._generate_returned_model(
67+
doc, requested_fields=requested_fields
68+
)
69+
deser_doc = doc_model(**doc)
70+
with pytest.raises(
71+
AttributeError, match="data is available but has not been requested in"
72+
):
73+
deser_doc.tags

0 commit comments

Comments
 (0)