88set of operations that are provided by this module and wrapped into a clean API.
99"""
1010import numpy as np
11+ from numba import TypingError
1112from datetime import datetime
1213from exatomic .base import sym2z
1314from .orbital_util import (
1415 numerical_grid_from_field_params , _determine_fps ,
1516 _determine_vector , _compute_orb_ang_mom , _compute_current_density ,
16- _compute_orbitals , _compute_density , _check_column , _make_field ,
17- _compute_orbitals_nojit )
17+ _compute_density , _check_column , _make_field ,
18+ _compute_orbitals_numba , _compute_orbitals_numpy )
1819
1920
2021def _setup_orbital (uni , verbose , vector , fps , icoefs , jcoefs = None , irrep = None ):
@@ -38,17 +39,20 @@ def _setup_orbital(uni, verbose, vector, fps, icoefs, jcoefs=None, irrep=None):
3839 return t1 , vector , fps , x , y , z , bvs , icoefs , jcoefs
3940 return t1 , vector , fps , x , y , z , bvs , icoefs
4041
42+ def _compute_orbital (verbose , npts , bvs , vector , cmat ):
43+ try : ovs = _compute_orbitals_numba (npts , bvs , vector , cmat )
44+ except (ValueError , IndexError , AssertionError , TypingError ) as e :
45+ if verbose : print ('numba eval failed, falling back to numpy' )
46+ ovs = _compute_orbitals_numpy (npts , bvs , vector , cmat )
47+ return ovs
4148
42- def _teardown_orbital (uni , verbose , field , t1 , inplace , replace , dens = False ):
49+ def _teardown_orbital (uni , verbose , field , t1 , inplace , name = 'orbitals' ):
4350 """Boilerplate for finishing the functions in this module."""
4451 if verbose :
4552 t2 = datetime .now ()
46- kind = 'density ' if dens else 'orbitals'
47- p2 = 'Timing: compute {} - {:>8.2f}s.'
48- print (p2 .format (kind , (t2 - t1 ).total_seconds ()))
53+ p2 = 'Timing: compute {:<8} - {:>8.2f}s.'
54+ print (p2 .format (name , (t2 - t1 ).total_seconds ()))
4955 if not inplace : return field
50- if replace and hasattr (uni , '_field' ):
51- del uni .__dict__ ['_field' ]
5256 uni .add_field (field )
5357
5458
@@ -76,14 +80,12 @@ def add_molecular_orbitals(uni, field_params=None, mocoefs=None,
7680 Warning:
7781 If replace is True, removes any fields previously attached to the universe
7882 """
83+ if replace and hasattr (uni , '_field' ): del uni .__dict__ ['_field' ]
7984 t1 , vector , fps , x , y , z , bvs , mocoefs = \
8085 _setup_orbital (uni , verbose , vector , field_params , mocoefs , irrep = irrep )
81- try : ovs = _compute_orbitals (len (x ), bvs , vector , mocoefs )
82- except (ValueError , IndexError , AssertionError ) as e :
83- if verbose : print ('Falling back to numpy orbital evaluation.' )
84- ovs = _compute_orbitals_nojit (len (x ), bvs , vector , mocoefs )
86+ ovs = _compute_orbital (verbose , len (x ), bvs , vector , mocoefs )
8587 field = _make_field (ovs , fps )
86- return _teardown_orbital (uni , verbose , field , t1 , inplace , replace )
88+ return _teardown_orbital (uni , verbose , field , t1 , inplace )
8789
8890
8991def add_density (uni , field_params = None , mocoefs = None , orbocc = None ,
@@ -106,13 +108,9 @@ def add_density(uni, field_params=None, mocoefs=None, orbocc=None,
106108 orbocc = _check_column (uni , 'orbital' , orbocc )
107109 vector = uni .orbital [~ np .isclose (uni .orbital [orbocc ], 0 )].index .values
108110 orbocc = uni .orbital .loc [vector ][orbocc ].values
109- try : ovs = _compute_orbitals (len (x ), bvs , vector , mocoefs )
110- except (ValueError , IndexError , AssertionError ) as e :
111- #if verbose:
112- print ('Falling back to numpy orbital evaluation.' )
113- ovs = _compute_orbitals_nojit (len (x ), bvs , vector , mocoefs )
111+ ovs = _compute_orbital (verbose , len (x ), bvs , vector , mocoefs )
114112 field = _make_field (_compute_density (ovs , orbocc ), fps .loc [0 ])
115- return _teardown_orbital (uni , verbose , field , t1 , inplace , False )
113+ return _teardown_orbital (uni , verbose , field , t1 , inplace , name = 'density' )
116114
117115
118116def add_orb_ang_mom (uni , field_params = None , rcoefs = None , icoefs = None ,
@@ -150,8 +148,6 @@ def add_orb_ang_mom(uni, field_params=None, rcoefs=None, icoefs=None,
150148 if verbose :
151149 p1 = 'Timing: grid evaluation - {:>8.2f}s.'
152150 print (p1 .format ((t2 - t1 ).total_seconds ()))
153- print (rcoefs .shape , rcoefs .dtype )
154- print (icoefs .shape , icoefs .dtype )
155151 curx , cury , curz = _compute_current_density (
156152 bvs , grx , gry , grz , rcoefs , icoefs , occvec , verbose = verbose )
157153 t3 = datetime .now ()
@@ -160,4 +156,4 @@ def add_orb_ang_mom(uni, field_params=None, rcoefs=None, icoefs=None,
160156 print (p2 .format ((t3 - t2 ).total_seconds ()))
161157 field = _make_field (_compute_orb_ang_mom (
162158 x , y , z , curx , cury , curz , maxes ), fps )
163- return _teardown_orbital (uni , False , field , t1 , inplace , False )
159+ return _teardown_orbital (uni , verbose , field , t1 , inplace , name = 'angmom' )
0 commit comments