Skip to content

Commit 10ac9a9

Browse files
committed
fix: exclude requested_fields from fields_not_requested
1 parent 87e3a79 commit 10ac9a9

2 files changed

Lines changed: 46 additions & 8 deletions

File tree

mp_api/client/core/client.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,12 @@ def _submit_request_and_process(
12051205
# other sub-urls may use different document models
12061206
# the client does not handle this in a particularly smart way currently
12071207
if self.document_model and use_document_model:
1208-
data["data"] = self._convert_to_model(data["data"])
1208+
requested_fields = (
1209+
params.get("_fields", "").split(",")
1210+
if params.get("_fields")
1211+
else None
1212+
)
1213+
data["data"] = self._convert_to_model(data["data"], requested_fields)
12091214

12101215
meta_total_doc_num = data.get("meta", {}).get("total_doc", 1)
12111216

@@ -1234,11 +1239,13 @@ def _submit_request_and_process(
12341239
def _convert_to_model(
12351240
self,
12361241
data: list[dict[str, Any]] | Iterator,
1242+
requested_fields: list[str] | None = None,
12371243
) -> list[BaseModel] | list[dict[str, Any]]:
12381244
"""Converts dictionary documents to instantiated MPDataDoc objects.
12391245
12401246
Args:
12411247
data (list[dict] or Iterator): Raw dictionary data objects
1248+
requested_fields (list[str] or None): Optional list of fields to be returned
12421249
12431250
Returns:
12441251
(list[MPDataDoc]): List of MPDataDoc objects
@@ -1252,7 +1259,9 @@ def _convert_to_model(
12521259
except StopIteration:
12531260
# Return empty list if no data in iterator
12541261
return []
1255-
data_model, set_fields, _ = self._generate_returned_model(first_doc)
1262+
data_model, set_fields, _ = self._generate_returned_model(
1263+
first_doc, requested_fields
1264+
)
12561265

12571266
return [
12581267
data_model(
@@ -1268,11 +1277,27 @@ def _convert_to_model(
12681277
return data
12691278

12701279
def _generate_returned_model(
1271-
self, doc: dict[str, Any]
1280+
self,
1281+
doc: dict[str, Any],
1282+
requested_fields: list[str] | None = None,
12721283
) -> tuple[type[BaseModel], list[str], list[str]]:
1284+
"""Dynamically generates an MPDataDoc Pydantic model from API response content.
1285+
1286+
Args:
1287+
doc (dict): A single document returned from the API
1288+
requested_fields (list[str] or None): Optional list of fields to be returned
1289+
1290+
Returns:
1291+
(tuple): Tuple containing (data_model, set_fields, fields_not_requested)
1292+
"""
12731293
model_fields = self.document_model.model_fields
12741294
set_fields = [k for k in doc if k in model_fields]
12751295
unset_fields = [field for field in model_fields if field not in set_fields]
1296+
fields_not_requested = (
1297+
[field for field in unset_fields if field not in requested_fields]
1298+
if requested_fields
1299+
else unset_fields
1300+
)
12761301

12771302
# Update with locals() from external module if needed
12781303
if any(
@@ -1299,9 +1324,7 @@ def _generate_returned_model(
12991324
data_model = create_model( # type: ignore
13001325
"MPDataDoc",
13011326
**include_fields,
1302-
# TODO fields_not_requested is not the same as unset_fields
1303-
# i.e. field could be requested but not available in the raw doc
1304-
fields_not_requested=(list[str], unset_fields),
1327+
fields_not_requested=(list[str], fields_not_requested),
13051328
__base__=_DictLikeAccess,
13061329
__doc__=".".join(
13071330
[
@@ -1331,7 +1354,7 @@ def new_str(self) -> str:
13311354
if n in set_fields
13321355
)
13331356

1334-
s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{unset_fields}" # noqa: E501
1357+
s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{fields_not_requested}" # noqa: E501
13351358
return s
13361359

13371360
def new_getattr(self, attr) -> str:
@@ -1354,7 +1377,7 @@ def new_dict(self, *args, **kwargs):
13541377
data_model.__getattr__ = new_getattr
13551378
data_model.dict = new_dict
13561379

1357-
return data_model, set_fields, unset_fields
1380+
return data_model, set_fields, fields_not_requested
13581381

13591382
def _query_resource_data(
13601383
self,

tests/client/test_core_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,18 @@ def test_count(mpr):
4545
def test_available_fields(rester, mpr):
4646
assert len(mpr.materials.available_fields) > 0
4747
assert rester.available_fields == []
48+
49+
50+
def test_fields_not_requested_excludes_requested_fields(rester: BaseRester):
51+
from emmet.core.tasks import TaskDoc
52+
53+
rester.document_model = TaskDoc
54+
doc = {"task_id": "fakeid-1234", "state": "successful"}
55+
requested_fields = list(TaskDoc.model_fields.keys())
56+
57+
_, _, fields_not_requested = rester._generate_returned_model(
58+
doc, requested_fields=requested_fields
59+
)
60+
61+
assert "dir_name" not in fields_not_requested
62+
assert "tags" not in fields_not_requested

0 commit comments

Comments
 (0)