@@ -126,13 +126,13 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
126126 self .load (grid )
127127 elif not (grid is None or edges is None ):
128128 # set up from histogramdd-type data
129- self .grid = numpy .asarray (grid )
129+ self .grid = numpy .asanyarray (grid )
130130 self .edges = edges
131131 self ._update ()
132132 elif not (grid is None or origin is None or delta is None ):
133133 # setup from generic data
134- origin = numpy .asarray (origin )
135- delta = numpy .asarray (delta )
134+ origin = numpy .asanyarray (origin )
135+ delta = numpy .asanyarray (delta )
136136 if len (origin ) != grid .ndim :
137137 raise TypeError (
138138 "Dimension of origin is not the same as grid dimension." )
@@ -148,7 +148,7 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
148148 self .edges = [origin [dim ] +
149149 (numpy .arange (m + 1 ) - 0.5 ) * delta [dim ]
150150 for dim , m in enumerate (grid .shape )]
151- self .grid = numpy .asarray (grid )
151+ self .grid = numpy .asanyarray (grid )
152152 self ._update ()
153153 else :
154154 # empty, must manually populate with load()
@@ -230,7 +230,7 @@ def resample(self, edges):
230230 coordinates = ndmeshgrid (* midpoints )
231231 # feed a meshgrid to generate all points
232232 newgrid = self .interpolated (* coordinates )
233- return Grid (newgrid , edges )
233+ return self . __class__ (newgrid , edges )
234234
235235 def resample_factor (self , factor ):
236236 """Resample to a new regular grid.
@@ -611,76 +611,76 @@ def __ne__(self, other):
611611
612612 def __add__ (self , other ):
613613 self .check_compatible (other )
614- return Grid (self .grid + _grid (other ), edges = self .edges )
614+ return self . __class__ (self .grid + _grid (other ), edges = self .edges )
615615
616616 def __sub__ (self , other ):
617617 self .check_compatible (other )
618- return Grid (self .grid - _grid (other ), edges = self .edges )
618+ return self . __class__ (self .grid - _grid (other ), edges = self .edges )
619619
620620 def __mul__ (self , other ):
621621 self .check_compatible (other )
622- return Grid (self .grid * _grid (other ), edges = self .edges )
622+ return self . __class__ (self .grid * _grid (other ), edges = self .edges )
623623
624624 def __truediv__ (self , other ):
625625 # truediv will always do true division (in Python 2 and Python 3);
626626 # we use from __future__ include division everywhere
627627 self .check_compatible (other )
628- return Grid (self .grid / _grid (other ), edges = self .edges )
628+ return self . __class__ (self .grid / _grid (other ), edges = self .edges )
629629
630630 def __div__ (self , other ):
631631 # in Python 2 only (without __future__.division): will do "classic division"
632632 # https://docs.python.org/2/reference/datamodel.html#object.__div__
633633 if not six .PY2 :
634634 raise NotImplementedError ("__div__ is only available in Python 2, use __truediv__" )
635635 self .check_compatible (other )
636- return Grid (self .grid .__div__ (_grid (other )), edges = self .edges )
636+ return self . __class__ (self .grid .__div__ (_grid (other )), edges = self .edges )
637637
638638 def __floordiv__ (self , other ):
639639 self .check_compatible (other )
640- return Grid (self .grid // _grid (other ), edges = self .edges )
640+ return self . __class__ (self .grid // _grid (other ), edges = self .edges )
641641
642642 def __pow__ (self , other ):
643643 self .check_compatible (other )
644- return Grid (numpy .power (self .grid , _grid (other )), edges = self .edges )
644+ return self . __class__ (numpy .power (self .grid , _grid (other )), edges = self .edges )
645645
646646 def __radd__ (self , other ):
647647 self .check_compatible (other )
648- return Grid (_grid (other ) + self .grid , edges = self .edges )
648+ return self . __class__ (_grid (other ) + self .grid , edges = self .edges )
649649
650650 def __rsub__ (self , other ):
651651 self .check_compatible (other )
652- return Grid (_grid (other ) - self .grid , edges = self .edges )
652+ return self . __class__ (_grid (other ) - self .grid , edges = self .edges )
653653
654654 def __rmul__ (self , other ):
655655 self .check_compatible (other )
656- return Grid (_grid (other ) * self .grid , edges = self .edges )
656+ return self . __class__ (_grid (other ) * self .grid , edges = self .edges )
657657
658658 def __rtruediv__ (self , other ):
659659 self .check_compatible (other )
660- return Grid (_grid (other ) / self .grid , edges = self .edges )
660+ return self . __class__ (_grid (other ) / self .grid , edges = self .edges )
661661
662662 def __rdiv__ (self , other ):
663663 # in Python 2 only (without __future__.division): will do "classic division"
664664 # https://docs.python.org/2/reference/datamodel.html#object.__div__
665665 if not six .PY2 :
666666 raise NotImplementedError ("__rdiv__ is only available in Python 2, use __rtruediv__" )
667667 self .check_compatible (other )
668- return Grid (self .grid .__rdiv__ (_grid (other )), edges = self .edges )
668+ return self . __class__ (self .grid .__rdiv__ (_grid (other )), edges = self .edges )
669669
670670 def __rfloordiv__ (self , other ):
671671 self .check_compatible (other )
672- return Grid (_grid (other ) // self .grid , edges = self .edges )
672+ return self . __class__ (_grid (other ) // self .grid , edges = self .edges )
673673
674674 def __rpow__ (self , other ):
675675 self .check_compatible (other )
676- return Grid (numpy .power (_grid (other ), self .grid ), edges = self .edges )
676+ return self . __class__ (numpy .power (_grid (other ), self .grid ), edges = self .edges )
677677
678678 def __repr__ (self ):
679679 try :
680680 bins = self .grid .shape
681681 except AttributeError :
682682 bins = "no"
683- return '<Grid with ' + str ( bins ) + ' bins>'
683+ return '<{0} with {1!r} bins>' . format ( self . __class__ , bins )
684684
685685
686686def ndmeshgrid (* arrs ):
@@ -709,7 +709,7 @@ def ndmeshgrid(*arrs):
709709 for i , arr in enumerate (arrs ):
710710 slc = [1 ] * dim
711711 slc [i ] = lens [i ]
712- arr2 = numpy .asarray (arr ).reshape (slc )
712+ arr2 = numpy .asanyarray (arr ).reshape (slc )
713713 for j , sz in enumerate (lens ):
714714 if j != i :
715715 arr2 = arr2 .repeat (sz , axis = j )
0 commit comments