Skip to content

Commit 6a25991

Browse files
committed
chore: changeg unittest to pytest
1 parent d19c025 commit 6a25991

1 file changed

Lines changed: 42 additions & 12 deletions

File tree

tests/test_symmetryutilities.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import unittest
2020

2121
import numpy
22+
import pytest
2223

2324
from diffpy.structure.spacegroups import GetSpaceGroup
2425
from diffpy.structure.symmetryutilities import (
@@ -109,24 +110,12 @@ def test_positionDifference(self):
109110
self.assertTrue(numpy.allclose(positionDifference([1.2, -0.1, 2.75], [0.1, 0.4, 0.25]), [0.1, 0.5, 0.5]))
110111
return
111112

112-
def test_position_difference(self):
113-
"""Check position_difference in normal and boundary cases."""
114-
self.assertTrue(numpy.allclose(position_difference([0.1, 0.9, 0.2], [0.8, 0.1, 0.8]), [0.3, 0.2, 0.4]))
115-
self.assertTrue(numpy.allclose(position_difference([1.2, -0.1, 2.75], [0.1, 0.4, 0.25]), [0.1, 0.5, 0.5]))
116-
return
117-
118113
def test_nearestSiteIndex(self):
119114
"""Check nearestSiteIndex with single and multiple sites."""
120115
self.assertEqual(nearestSiteIndex([[0.1, 0.9, 0.2], [0.8, 0.1, 0.8]], [0.8, 0.1, 0.8]), 1)
121116
self.assertEqual(nearestSiteIndex([[1.2, -0.1, 2.75]], [0.7, 0.4, 0.25]), 0)
122117
return
123118

124-
def test_nearest_site_index(self):
125-
"""Check nearest_site_index with single and multiple sites."""
126-
self.assertEqual(nearest_site_index([[0.1, 0.9, 0.2], [0.8, 0.1, 0.8]], [0.8, 0.1, 0.8]), 1)
127-
self.assertEqual(nearest_site_index([[1.2, -0.1, 2.75]], [0.7, 0.4, 0.25]), 0)
128-
return
129-
130119
def test_expandPosition(self):
131120
"""Check expandPosition()"""
132121
# ok again Ni example
@@ -674,5 +663,46 @@ def test_UparValues(self):
674663

675664
# ----------------------------------------------------------------------------
676665

666+
667+
@pytest.mark.parametrize(
668+
"xyz0, xyz1, expected",
669+
[
670+
pytest.param( # C1: Generic case for symmetry mapping for periodic lattice
671+
[0.1, 0.9, 0.2],
672+
[0.8, 0.1, 0.8],
673+
[0.3, 0.2, 0.4],
674+
),
675+
pytest.param( # C2: Boundary case for entries with mapping on difference equal to 0.5
676+
[1.2, -0.1, 2.75],
677+
[0.1, 0.4, 0.25],
678+
[0.1, 0.5, 0.5],
679+
),
680+
],
681+
)
682+
def test_position_difference(xyz0, xyz1, expected):
683+
actual = position_difference(xyz0, xyz1)
684+
assert numpy.allclose(actual, expected)
685+
686+
687+
@pytest.mark.parametrize(
688+
"sites, xyz, expected",
689+
[
690+
pytest.param( # C1: We have two sites, and the xyz is closest to the index 1 site
691+
[[0.1, 0.9, 0.2], [0.8, 0.1, 0.8]],
692+
[0.8, 0.1, 0.8],
693+
1,
694+
),
695+
pytest.param( # C2: we have one site, and the xyz is closest to the index 0 site by default
696+
[[1.2, -0.1, 2.75]],
697+
[0.7, 0.4, 0.25],
698+
0,
699+
),
700+
],
701+
)
702+
def test_nearest_site_index(sites, xyz, expected):
703+
actual = nearest_site_index(sites, xyz)
704+
assert actual == expected
705+
706+
677707
if __name__ == "__main__":
678708
unittest.main()

0 commit comments

Comments
 (0)