Skip to content

Commit 2d0ffff

Browse files
fix: exclude requested fields from fields_not_requested (#1066)
2 parents 87e3a79 + cb11eeb commit 2d0ffff

2 files changed

Lines changed: 75 additions & 16 deletions

File tree

mp_api/client/core/client.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,14 @@ def _submit_request_and_process(
12051205
# other sub-urls may use different document models
12061206
# the client does not handle this in a particularly smart way currently
12071207
if self.document_model and use_document_model:
1208-
data["data"] = self._convert_to_model(data["data"])
1208+
data["data"] = self._convert_to_model(
1209+
data["data"],
1210+
requested_fields=(
1211+
params["_fields"].split(",")
1212+
if isinstance(params.get("_fields"), str)
1213+
else None
1214+
),
1215+
)
12091216

12101217
meta_total_doc_num = data.get("meta", {}).get("total_doc", 1)
12111218

@@ -1234,11 +1241,13 @@ def _submit_request_and_process(
12341241
def _convert_to_model(
12351242
self,
12361243
data: list[dict[str, Any]] | Iterator,
1244+
requested_fields: list[str] | None = None,
12371245
) -> list[BaseModel] | list[dict[str, Any]]:
12381246
"""Converts dictionary documents to instantiated MPDataDoc objects.
12391247
12401248
Args:
12411249
data (list[dict] or Iterator): Raw dictionary data objects
1250+
requested_fields (list[str] or None): Optional list of fields to be returned
12421251
12431252
Returns:
12441253
(list[MPDataDoc]): List of MPDataDoc objects
@@ -1252,14 +1261,15 @@ def _convert_to_model(
12521261
except StopIteration:
12531262
# Return empty list if no data in iterator
12541263
return []
1255-
data_model, set_fields, _ = self._generate_returned_model(first_doc)
1264+
data_model, set_fields, _ = self._generate_returned_model(
1265+
first_doc, requested_fields=requested_fields
1266+
)
12561267

12571268
return [
12581269
data_model(
12591270
**{
1260-
field: value
1261-
for field, value in dict(raw_doc).items()
1262-
if field in set_fields
1271+
field: raw_doc[field]
1272+
for field in set_fields.intersection(raw_doc)
12631273
}
12641274
)
12651275
for raw_doc in (data if is_list else chain([first_doc], data))
@@ -1268,11 +1278,26 @@ def _convert_to_model(
12681278
return data
12691279

12701280
def _generate_returned_model(
1271-
self, doc: dict[str, Any]
1272-
) -> tuple[type[BaseModel], list[str], list[str]]:
1281+
self,
1282+
doc: dict[str, Any],
1283+
requested_fields: list[str] | None = None,
1284+
) -> tuple[type[BaseModel], set[str], set[str]]:
1285+
"""Dynamically generates an MPDataDoc Pydantic model from API response content.
1286+
1287+
Args:
1288+
doc (dict): A single document returned from the API
1289+
requested_fields (list of str, or None): Optional list of fields to be returned
1290+
1291+
Returns:
1292+
BaseModel: the pydantic model representing the data
1293+
set of str: fields set in the document model
1294+
set of str: set_fields, fields_not_requested)
1295+
"""
12731296
model_fields = self.document_model.model_fields
1274-
set_fields = [k for k in doc if k in model_fields]
1275-
unset_fields = [field for field in model_fields if field not in set_fields]
1297+
set_fields = set(doc).intersection(model_fields)
1298+
unset_fields = set(model_fields).difference(set_fields)
1299+
user_requested_fields: list[str] = requested_fields or []
1300+
fields_not_requested = unset_fields.difference(user_requested_fields)
12761301

12771302
# Update with locals() from external module if needed
12781303
if any(
@@ -1299,9 +1324,11 @@ def _generate_returned_model(
12991324
data_model = create_model( # type: ignore
13001325
"MPDataDoc",
13011326
**include_fields,
1302-
# TODO fields_not_requested is not the same as unset_fields
1303-
# i.e. field could be requested but not available in the raw doc
1304-
fields_not_requested=(list[str], unset_fields),
1327+
fields_not_requested=(list[str], list(fields_not_requested)),
1328+
unavailable_fields=(
1329+
list[str],
1330+
list(unset_fields.intersection(user_requested_fields)),
1331+
),
13051332
__base__=_DictLikeAccess,
13061333
__doc__=".".join(
13071334
[
@@ -1331,13 +1358,19 @@ def new_str(self) -> str:
13311358
if n in set_fields
13321359
)
13331360

1334-
s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{unset_fields}" # noqa: E501
1335-
return s
1361+
return (
1362+
f"\033[4m\033[1m{self.__class__.__name__}"
1363+
f"<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m"
1364+
f"\n{extra}\n\n"
1365+
f"\033[1mFields not requested:\033[0;0m\n{fields_not_requested}"
1366+
)
13361367

13371368
def new_getattr(self, attr) -> str:
1369+
if attr in self.unavailable_fields:
1370+
raise AttributeError(f"`{attr}` is unavailable in the returned data.")
13381371
if attr in self.fields_not_requested:
13391372
raise AttributeError(
1340-
f"'{attr}' data is available but has not been requested in 'fields'."
1373+
f"`{attr}` data is available but has not been requested in `fields`."
13411374
" A full list of unrequested fields can be found in `fields_not_requested`."
13421375
)
13431376
else:
@@ -1354,7 +1387,7 @@ def new_dict(self, *args, **kwargs):
13541387
data_model.__getattr__ = new_getattr
13551388
data_model.dict = new_dict
13561389

1357-
return data_model, set_fields, unset_fields
1390+
return data_model, set_fields, fields_not_requested
13581391

13591392
def _query_resource_data(
13601393
self,

tests/client/test_core_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,29 @@ def test_count(mpr):
4545
def test_available_fields(rester, mpr):
4646
assert len(mpr.materials.available_fields) > 0
4747
assert rester.available_fields == []
48+
49+
50+
def test_fields_not_requested_excludes_requested_fields(mpr):
51+
task_rester = mpr.materials.tasks
52+
doc = {"task_id": "fakeid-1234", "state": "successful"}
53+
requested_fields = list(task_rester.document_model.model_fields)
54+
55+
doc_model, _, fields_not_requested = task_rester._generate_returned_model(
56+
doc, requested_fields=requested_fields
57+
)
58+
deser_doc = doc_model(**doc)
59+
60+
assert "dir_name" not in fields_not_requested
61+
assert "tags" in deser_doc.unavailable_fields
62+
with pytest.raises(AttributeError, match="is unavailable in the returned data"):
63+
deser_doc.tags
64+
65+
requested_fields.remove("tags")
66+
doc_model, _, _ = task_rester._generate_returned_model(
67+
doc, requested_fields=requested_fields
68+
)
69+
deser_doc = doc_model(**doc)
70+
with pytest.raises(
71+
AttributeError, match="data is available but has not been requested in"
72+
):
73+
deser_doc.tags

0 commit comments

Comments
 (0)