Skip to content

Commit b252ea3

Browse files
committed
get_chemical_symbols function and test
1 parent 39d41f0 commit b252ea3

2 files changed

Lines changed: 219 additions & 0 deletions

File tree

src/diffpy/structure/structure.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy as copymod
1818

1919
import numpy
20+
from ase import Atoms as ASEAtoms
2021

2122
from diffpy.structure.atom import Atom
2223
from diffpy.structure.lattice import Lattice
@@ -173,6 +174,179 @@ def getLastAtom(self):
173174
last_atom = self[-1]
174175
return last_atom
175176

177+
def get_chemical_symbols(self, include_charge_state=False):
178+
"""Return list of chemical symbols for all `Atoms` in this
179+
structure.
180+
181+
Parameters
182+
----------
183+
include_charge_state : bool, optional
184+
If ``True``, include charge state in the chemical symbol (e.g., "Fe2+").
185+
Returns
186+
-------
187+
list of str
188+
The list of chemical symbols for all `Atoms` in this structure.
189+
"""
190+
symbols_with_charge = [a.element for a in self]
191+
if include_charge_state:
192+
return symbols_with_charge
193+
else:
194+
symbols = [atomBareSymbol(sym) for sym in symbols_with_charge]
195+
return symbols
196+
197+
def get_fractional_coordinates(self):
198+
"""Return array of fractional coordinates of all `Atoms` in this
199+
structure.
200+
201+
Returns
202+
-------
203+
numpy.ndarray
204+
The array of fractional coordinates of all `Atoms` in this structure.
205+
"""
206+
coords = numpy.array([a.xyz for a in self])
207+
return coords
208+
209+
def get_cartesian_coordinates(self):
210+
"""Return array of Cartesian coordinates of all `Atoms` in this
211+
structure.
212+
213+
Returns
214+
-------
215+
numpy.ndarray
216+
The array of Cartesian coordinates of all `Atoms` in this structure.
217+
"""
218+
cartn_coords = numpy.array([a.xyz_cartn for a in self])
219+
return cartn_coords
220+
221+
def get_anisotropic_displacement_parameters(self):
222+
"""Return array of anisotropic displacement parameters of all
223+
`Atoms` in this structure.
224+
225+
Returns
226+
-------
227+
numpy.ndarray
228+
The array of anisotropic displacement parameters of all `Atoms` in this structure.
229+
"""
230+
adps = numpy.array([a.U for a in self])
231+
return adps
232+
233+
def get_isotropic_displacement_parameters(self):
234+
"""Return array of isotropic displacement parameters of all
235+
`Atoms` in this structure.
236+
237+
Returns
238+
-------
239+
numpy.ndarray
240+
The array of isotropic displacement parameters of all `Atoms` in this structure.
241+
"""
242+
idps = numpy.array([a.Uisoequiv for a in self])
243+
return idps
244+
245+
def get_occupancies(self):
246+
"""Return array of occupancies of all `Atoms` in this structure.
247+
248+
Returns
249+
-------
250+
numpy.ndarray
251+
The array of occupancies of all `Atoms` in this structure.
252+
"""
253+
occupancies = numpy.array([a.occupancy for a in self])
254+
return occupancies
255+
256+
def convert_ase_to_diffpy_structure(
257+
self,
258+
ase_atoms: ASEAtoms,
259+
lost_info: list[str] | None = None,
260+
) -> Structure | tuple[Structure, dict]: # noqa
261+
"""Convert ASE `Atoms` object to this `Structure` instance.
262+
263+
Parameters
264+
----------
265+
ase_structure : ase.Atoms
266+
The ASE `Atoms` object to be converted.
267+
lost_info : list of str, optional
268+
The list of attribute names to extract from the ASE `Atoms`
269+
object that do not have a direct equivalent in the `Structure` class.
270+
object that is not currently available in the `Structure` class.
271+
Default is False.
272+
273+
Returns
274+
-------
275+
Structure
276+
Reference to this `Structure` object with updated attributes and `Atom` instances.
277+
lost_info : dict, optional
278+
The dictionary containing any information from the ASE `Atoms`
279+
object that is not currently available in the `Structure` class.
280+
Default behavior is to return only the `Structure` instance.
281+
If `lost_info` is provided, it will be a dictionary containing
282+
any information from the ASE `Atoms`.
283+
This may include information such as magnetic moments, charge states,
284+
or other ASE-specific properties that do not have a direct equivalent
285+
in the `Structure` class.
286+
287+
Raises
288+
------
289+
TypeError
290+
If the input `ase_structure` is not an instance of `ase.Atoms`.
291+
ValueError
292+
If any of the specified `lost_info` attributes are not present in the ASE `Atoms` object.
293+
294+
Examples
295+
--------
296+
An example of converting an `ASE.Atoms` instance to a `Structure` instance,
297+
298+
.. code-block:: python
299+
from ase import Atoms
300+
from diffpy.structure import Structure
301+
302+
# Create an ASE Atoms object
303+
ase_atoms = Atoms('H2O', positions=[[0, 0, 0], [0, 0, 1], [1, 0, 0]])
304+
305+
# Convert to a diffpy Structure object
306+
structure = Structure()
307+
structure.convert_ase_to_diffpy(ase_atoms,
308+
309+
310+
To extract additional information from the ASE `Atoms` object that is not
311+
directly represented in the `Structure` class, such as magnetic moments,
312+
you can specify an attribute or method of `ASE.Atoms` as
313+
a list of strings in `lost_info` list. For example,
314+
315+
.. code-block:: python
316+
lost_info = structure.convert_ase_to_diffpy(
317+
ase_atoms,
318+
lost_info=['get_magnetic_moments']
319+
)
320+
321+
will return a dictionary with the magnetic moments of the atoms in the ASE `Atoms` object.
322+
"""
323+
if not isinstance(ase_atoms, ASEAtoms):
324+
raise TypeError("Input must be an instance of ase.Atoms.")
325+
# --- structure conversion ---
326+
symbols = ase_atoms.get_chemical_symbols()
327+
scaled_positions = ase_atoms.get_scaled_positions()
328+
for sym, xyz in zip(symbols, scaled_positions):
329+
self.append(Atom(sym, xyz=xyz))
330+
# --- optional extraction ---
331+
if lost_info is None:
332+
return
333+
extracted_info = {}
334+
for name in lost_info:
335+
if not hasattr(ase_atoms, name):
336+
raise ValueError(f"ASE.Atoms object has no attribute '{name}'.")
337+
try:
338+
attr = getattr(ase_atoms, name)
339+
value = attr() if callable(attr) else attr
340+
# try to copy (safe for numpy arrays, dicts, etc.)
341+
try:
342+
value = value.copy()
343+
except Exception:
344+
pass
345+
extracted_info[name] = value
346+
except Exception as e:
347+
extracted_info[name] = f"ERROR: {type(e).__name__}: {e}"
348+
return extracted_info
349+
176350
def assign_unique_labels(self):
177351
"""Set a unique label string for each `Atom` in this structure.
178352

tests/test_structure.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,51 @@ def test_pickling(self):
609609

610610
# End of class TestStructure
611611

612+
613+
@pytest.mark.parametrize(
614+
"include_charge_state,expected",
615+
[
616+
(False, ["Pb"] * 4 + ["Te"] * 4),
617+
(True, ["Pb"] * 4 + ["Te"] * 4),
618+
],
619+
)
620+
def test_get_chemical_symbols(datafile, include_charge_state, expected):
621+
"""Check Structure.get_chemical_symbols()"""
622+
pbte_stru = Structure(filename=datafile("PbTe.cif"))
623+
actual_chemical_symbols = pbte_stru.get_chemical_symbols(include_charge_state=include_charge_state)
624+
expected_chemical_symbols = expected
625+
assert actual_chemical_symbols == expected_chemical_symbols
626+
627+
628+
# def test_get_fractional_coordinates(datafile):
629+
# """Check Structure.get_fractional_coordinates()"""
630+
# pbte_cif = Structure(filename=datafile("PbTe.cif"))
631+
# assert False
632+
633+
634+
# def test_get_cartesian_coordinates(datafile):
635+
# """Check Structure.get_cartesian_coordinates()"""
636+
# assert False
637+
638+
639+
# def test_get_anisotropic_displacement_parameters(datafile):
640+
# """Check Structure.get_anisotropic_displacement_parameters()"""
641+
# assert False
642+
643+
644+
# def test_get_isotropic_displacement_parameters(datafile):
645+
# """Check Structure.get_isotropic_displacement_parameters()"""
646+
# assert False
647+
648+
649+
# def test_get_occupancies(datafile):
650+
# """Check Structure.get_occupancies()"""
651+
# assert False
652+
653+
# def test_convert_ase_to_diffpy_structure(datafile):
654+
# """Check convert_ase_to_diffpy_structure()"""
655+
# assert False
656+
612657
# ----------------------------------------------------------------------------
613658

614659
if __name__ == "__main__":

0 commit comments

Comments
 (0)