Skip to content

Commit cd1718e

Browse files
reorg tests slightly
2 parents 28993f8 + 5994c58 commit cd1718e

18 files changed

Lines changed: 260 additions & 1474 deletions

mp_api/client/core/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __init__(
191191
warnings.warn(
192192
"Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`."
193193
"The client by default returns results consistent with `monty_decode=True`.",
194-
category=DeprecationWarning,
194+
category=MPRestWarning,
195195
stacklevel=2,
196196
)
197197

@@ -1349,7 +1349,7 @@ def new_str(self) -> str:
13491349

13501350
return (
13511351
f"\033[4m\033[1m{self.__class__.__name__}"
1352-
f"<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m"
1352+
f"<{orig_rester_name}>\033[0;0m\033[0;0m"
13531353
f"\n{extra}\n\n"
13541354
f"\033[1mFields not requested:\033[0;0m\n{fields_not_requested}"
13551355
)
@@ -1597,7 +1597,7 @@ def __getattr__(self, v: str):
15971597
self.sub_resters[v](
15981598
api_key=self.api_key,
15991599
endpoint=self.base_endpoint,
1600-
include_user_agent=self._include_user_agent,
1600+
include_user_agent=self.include_user_agent,
16011601
session=self.session,
16021602
use_document_model=self.use_document_model,
16031603
headers=self.headers,
@@ -1606,6 +1606,7 @@ def __getattr__(self, v: str):
16061606
force_renew=self.force_renew,
16071607
)
16081608
return self.sub_resters[v]
1609+
raise AttributeError(f"{self.__class__} has no attribute {v}")
16091610

16101611
def __dir__(self):
16111612
return dir(self.__class__) + list(self._sub_resters)

mp_api/client/mprester.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,25 +1076,28 @@ def get_entries_in_chemsys(
10761076
if isinstance(elements, str):
10771077
elements = elements.split("-")
10781078

1079-
elements_set = set(elements) # remove duplicate elements
1079+
# 9 elements would be sum_{i=1}^{9} (9 choose i) = 511
1080+
# From testing, this is the highest number of chemsys
1081+
# we can query before URI lengths are exceeded
1082+
if len(elements_set := set(elements)) > 9: # remove duplicate elements
1083+
raise MPRestError(
1084+
"Please specify fewer elements to query by, "
1085+
"or identify a subset of relevant chemical systems to query first."
1086+
)
10801087

10811088
all_chemsyses = [
10821089
"-".join(sorted(els))
10831090
for i in range(len(elements_set))
10841091
for els in itertools.combinations(elements_set, i + 1)
10851092
]
10861093

1087-
entries = []
1088-
1089-
entries.extend(
1090-
self.get_entries(
1091-
all_chemsyses,
1092-
compatible_only=compatible_only,
1093-
property_data=property_data,
1094-
conventional_unit_cell=conventional_unit_cell,
1095-
additional_criteria=additional_criteria or DEFAULT_THERMOTYPE_CRITERIA,
1096-
**kwargs,
1097-
)
1094+
entries = self.get_entries(
1095+
all_chemsyses,
1096+
compatible_only=compatible_only,
1097+
property_data=property_data,
1098+
conventional_unit_cell=conventional_unit_cell,
1099+
additional_criteria=additional_criteria or DEFAULT_THERMOTYPE_CRITERIA,
1100+
**kwargs,
10981101
)
10991102

11001103
if use_gibbs:
@@ -1261,21 +1264,41 @@ def get_charge_density_from_material_id(
12611264
task_id = latest_doc["task_id"]
12621265
return self.get_charge_density_from_task_id(task_id, inc_task_doc)
12631266

1264-
def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
1267+
def get_download_info(
1268+
self,
1269+
material_ids: str | MPID | list[str | MPID],
1270+
calc_types: list[str | CalcType] | None = None,
1271+
file_patterns: list[str] | None = None,
1272+
):
12651273
"""Get a list of URLs to retrieve raw VASP output files from the NoMaD repository
12661274
Args:
1267-
material_ids (list): list of material identifiers (mp-id's)
1268-
task_types (list): list of task types to include in download (see CalcType Enum class)
1275+
material_ids (str or MPID, or list thereof): list of material identifiers (mp-id's)
1276+
calc_types (list of str or CalcType): list of calc types to include in download (see CalcType Enum class)
12691277
file_patterns (list): list of wildcard file names to include for each task
12701278
Returns:
12711279
a tuple of 1) a dictionary mapping material_ids to task_ids and
12721280
calc_types, and 2) a list of URLs to download zip archives from
12731281
NoMaD repository. Each zip archive will contain a manifest.json with
12741282
metadata info, e.g. the task/external_ids that belong to a directory.
12751283
"""
1284+
warnings.warn(
1285+
"Full downloads of raw data are being transitioned to "
1286+
"Materials Project's AWS S3 OpenData buckets. "
1287+
"These features for accessing legacy raw data via NOMAD "
1288+
"are maintained but may not be supported in the future.",
1289+
category=MPRestWarning,
1290+
stacklevel=2,
1291+
)
1292+
12761293
# task_id's correspond to NoMaD external_id's
1294+
if isinstance(material_ids, str | MPID):
1295+
material_ids = [material_ids]
1296+
12771297
calc_types = (
1278-
[t.value for t in calc_types if isinstance(t, CalcType)]
1298+
[
1299+
t.value if isinstance(t, CalcType) else CalcType(t).value
1300+
for t in calc_types
1301+
]
12791302
if calc_types
12801303
else []
12811304
)

mp_api/client/routes/materials/materials.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from pathlib import Path
34
from typing import TYPE_CHECKING
45

56
from emmet.core.symmetry import CrystalSystem
@@ -172,11 +173,11 @@ def search(
172173

173174
def find_structure(
174175
self,
175-
filename_or_structure,
176+
filename_or_structure: str | Path | Structure,
176177
ltol=MAPI_CLIENT_SETTINGS.LTOL,
177178
stol=MAPI_CLIENT_SETTINGS.STOL,
178179
angle_tol=MAPI_CLIENT_SETTINGS.ANGLE_TOL,
179-
allow_multiple_results=False,
180+
allow_multiple_results: bool | int = False,
180181
) -> list[str] | str:
181182
"""Finds matching structures from the Materials Project database.
182183
@@ -186,48 +187,75 @@ def find_structure(
186187
default tolerances.
187188
188189
Args:
189-
filename_or_structure: filename or Structure object
190+
filename_or_structure: filename as a str or Path, or a Structure object
190191
ltol: fractional length tolerance
191192
stol: site tolerance
192193
angle_tol: angle tolerance in degrees
193-
allow_multiple_results: changes return type for either
194-
a single material_id or list of material_ids
194+
allow_multiple_results (bool or int): changes return type for either
195+
a single material_id or list of material_ids.
196+
If a bool, returns either all matches (True) or one match at most (False).
197+
If an int, returns that many matches at most.
198+
195199
Returns:
196200
A matching material_id if one is found or list of results if allow_multiple_results
197201
is True
198202
Raises:
199203
MPRestError
200204
"""
201-
params = {"ltol": ltol, "stol": stol, "angle_tol": angle_tol, "_limit": 1}
205+
from pymatgen.analysis.structure_matcher import (
206+
ElementComparator,
207+
StructureMatcher,
208+
)
202209

203-
if isinstance(filename_or_structure, str):
210+
if (
211+
isinstance(filename_or_structure, str | Path)
212+
and Path(filename_or_structure).exists()
213+
):
204214
s = Structure.from_file(filename_or_structure)
205215
elif isinstance(filename_or_structure, Structure):
206216
s = filename_or_structure
207217
else:
208218
raise MPRestError("Provide filename or Structure object.")
209219

210-
results = self._post_resource(
211-
body=s.as_dict(),
212-
params=params,
213-
suburl="find_structure",
214-
use_document_model=False,
215-
).get("data")
216-
217-
if not results:
220+
mat_docs = self.search(
221+
formula=s.reduced_formula, fields=["material_id", "structure"]
222+
)
223+
if not mat_docs:
218224
return []
219225

220-
material_ids = validate_ids([doc["material_id"] for doc in results])
226+
if isinstance(allow_multiple_results, bool):
227+
max_matches: int = len(mat_docs) if allow_multiple_results else 1
228+
elif isinstance(allow_multiple_results, int):
229+
max_matches = allow_multiple_results
230+
else:
231+
raise MPRestError(
232+
f"`allow_multiple_results` must be a bool or int, not {type(allow_multiple_results)}"
233+
)
221234

222-
if len(material_ids) > 1: # type: ignore
223-
if not allow_multiple_results:
224-
raise ValueError(
225-
"Multiple matches found for this combination of tolerances, but "
226-
"`allow_multiple_results` set to False."
227-
)
228-
return material_ids # type: ignore
235+
matcher = StructureMatcher(
236+
ltol=ltol,
237+
stol=stol,
238+
angle_tol=angle_tol,
239+
primitive_cell=True,
240+
scale=True,
241+
attempt_supercell=False,
242+
comparator=ElementComparator(),
243+
)
229244

230-
return material_ids[0]
245+
matches: list[str] = []
246+
for doc in mat_docs:
247+
if matcher.fit(
248+
s,
249+
doc.structure if self.use_document_model else Structure.from_dict(doc["structure"]), # type: ignore
250+
):
251+
matches.append(doc.material_id.string if self.use_document_model else doc["material_id"]) # type: ignore
252+
if len(matches) >= max_matches:
253+
break
254+
255+
if not matches:
256+
return []
257+
material_ids = validate_ids(matches)
258+
return material_ids if allow_multiple_results else material_ids[0]
231259

232260
def get_blessed_entries(
233261
self,

mp_api/client/routes/molecules/bonds.py

Lines changed: 0 additions & 134 deletions
This file was deleted.

0 commit comments

Comments
 (0)