11from contextlib import contextmanager
2+ from enum import Enum
23
34from .common import KeyStreamIterator
45from .lib import ffi , lib , checked_call
@@ -55,14 +56,40 @@ def get_set(self):
5556 return Set (None , _pointer = self ._set_ptr )
5657
5758
59+ class OpBuilderInputType (Enum ):
60+ SET = 1
61+ STREAM_BUILDER = 2
62+
63+
5864class OpBuilder (object ):
59- def __init__ (self , set_ptr ):
65+
66+ _BUILDERS = {
67+ OpBuilderInputType .SET : lib .fst_set_make_opbuilder ,
68+ OpBuilderInputType .STREAM_BUILDER : lib .fst_set_make_opbuilder_streambuilder ,
69+ }
70+ _PUSHERS = {
71+ OpBuilderInputType .SET : lib .fst_set_opbuilder_push ,
72+ OpBuilderInputType .STREAM_BUILDER : lib .fst_set_opbuilder_push_streambuilder ,
73+ }
74+
75+ @classmethod
76+ def from_slice (cls , set_ptr , s ):
77+ sb = StreamBuilder .from_slice (set_ptr , s )
78+ opbuilder = OpBuilder (sb ._ptr ,
79+ input_type = OpBuilderInputType .STREAM_BUILDER )
80+ return opbuilder
81+
82+ def __init__ (self , ptr , input_type = OpBuilderInputType .SET ):
83+ if input_type not in self ._BUILDERS :
84+ raise ValueError (
85+ "input_type must be a member of OpBuilderInputType." )
86+ self ._input_type = input_type
6087 # NOTE: No need for `ffi.gc`, since the struct will be free'd
6188 # once we call union/intersection/difference
62- self ._ptr = lib . fst_set_make_opbuilder ( set_ptr )
89+ self ._ptr = OpBuilder . _BUILDERS [ self . _input_type ]( ptr )
6390
64- def push (self , set_ptr ):
65- lib . fst_set_opbuilder_push (self ._ptr , set_ptr )
91+ def push (self , ptr ):
92+ OpBuilder . _PUSHERS [ self . _input_type ] (self ._ptr , ptr )
6693
6794 def union (self ):
6895 stream_ptr = lib .fst_set_opbuilder_union (self ._ptr )
@@ -86,6 +113,44 @@ def symmetric_difference(self):
86113 lib .fst_set_symmetricdifference_free )
87114
88115
116+ class StreamBuilder (object ):
117+
118+ @classmethod
119+ def from_slice (cls , set_ptr , slice_bounds ):
120+ sb = StreamBuilder (set_ptr )
121+ if slice_bounds .start :
122+ sb .ge (slice_bounds .start )
123+ if slice_bounds .stop :
124+ sb .lt (slice_bounds .stop )
125+ return sb
126+
127+ def __init__ (self , set_ptr ):
128+ # NOTE: No need for `ffi.gc`, since the struct will be free'd
129+ # once we call union/intersection/difference
130+ self ._ptr = lib .fst_set_streambuilder_new (set_ptr )
131+
132+ def finish (self ):
133+ stream_ptr = lib .fst_set_streambuilder_finish (self ._ptr )
134+ return KeyStreamIterator (stream_ptr , lib .fst_set_stream_next ,
135+ lib .fst_set_stream_free )
136+
137+ def ge (self , bound ):
138+ c_start = ffi .new ("char[]" , bound .encode ('utf8' ))
139+ self ._ptr = lib .fst_set_streambuilder_add_ge (self ._ptr , c_start )
140+
141+ def gt (self , bound ):
142+ c_start = ffi .new ("char[]" , bound .encode ('utf8' ))
143+ self ._ptr = lib .fst_set_streambuilder_add_gt (self ._ptr , c_start )
144+
145+ def le (self , bound ):
146+ c_end = ffi .new ("char[]" , bound .encode ('utf8' ))
147+ self ._ptr = lib .fst_set_streambuilder_add_le (self ._ptr , c_end )
148+
149+ def lt (self , bound ):
150+ c_end = ffi .new ("char[]" , bound .encode ('utf8' ))
151+ self ._ptr = lib .fst_set_streambuilder_add_lt (self ._ptr , c_end )
152+
153+
89154class Set (object ):
90155 """ An immutable ordered string set backed by a finite state transducer.
91156
@@ -203,19 +268,11 @@ def __getitem__(self, s):
203268 if s .start and s .stop and s .start > s .stop :
204269 raise ValueError (
205270 "Start key must be lexicographically smaller than stop." )
206- sb_ptr = lib .fst_set_streambuilder_new (self ._ptr )
207- if s .start :
208- c_start = ffi .new ("char[]" , s .start .encode ('utf8' ))
209- sb_ptr = lib .fst_set_streambuilder_add_ge (sb_ptr , c_start )
210- if s .stop :
211- c_stop = ffi .new ("char[]" , s .stop .encode ('utf8' ))
212- sb_ptr = lib .fst_set_streambuilder_add_lt (sb_ptr , c_stop )
213- stream_ptr = lib .fst_set_streambuilder_finish (sb_ptr )
214- return KeyStreamIterator (stream_ptr , lib .fst_set_stream_next ,
215- lib .fst_set_stream_free )
271+ sb = StreamBuilder .from_slice (self ._ptr , s )
272+ return sb .finish ()
216273
217274 def _make_opbuilder (self , * others ):
218- opbuilder = OpBuilder (self ._ptr )
275+ opbuilder = OpBuilder (self ._ptr , input_type = OpBuilderInputType . SET )
219276 for oth in others :
220277 opbuilder .push (oth ._ptr )
221278 return opbuilder
@@ -333,3 +390,65 @@ def search(self, term, max_dist):
333390 return KeyStreamIterator (stream_ptr , lib .fst_set_levstream_next ,
334391 lib .fst_set_levstream_free , lev_ptr ,
335392 lib .fst_levenshtein_free )
393+
394+
395+ class UnionSet (object ):
396+ """ A collection of Set objects that offer efficient operations across all
397+ members.
398+ """
399+ def __init__ (self , * sets ):
400+ self .sets = list (sets )
401+
402+ def __contains__ (self , val ):
403+ """ Check if the set contains the value. """
404+ return any ([
405+ lib .fst_set_contains (fst ._ptr ,
406+ ffi .new ("char[]" ,
407+ val .encode ('utf8' )))
408+ for fst in self .sets
409+ ])
410+
411+ def __getitem__ (self , s ):
412+ """ Get an iterator over a range of set contents.
413+
414+ Start and stop indices of the slice must be unicode strings.
415+
416+ .. important::
417+ Slicing follows the semantics for numerical indices, i.e. the
418+ `stop` value is **exclusive**. For example, given the set
419+ `s = Set.from_iter(["bar", "baz", "foo", "moo"])`, `s['b': 'f']`
420+ will only return `"bar"` and `"baz"`.
421+
422+ :param s: A slice that specifies the range of the set to retrieve
423+ :type s: :py:class:`slice`
424+ """
425+ if not isinstance (s , slice ):
426+ raise ValueError (
427+ "Value must be a string slice (e.g. `['foo':]`)" )
428+ if s .start and s .stop and s .start > s .stop :
429+ raise ValueError (
430+ "Start key must be lexicographically smaller than stop." )
431+ if len (self .sets ) <= 1 :
432+ raise ValueError (
433+ "Must have more than one set to operate on." )
434+
435+ opbuilder = OpBuilder .from_slice (self .sets [0 ]._ptr , s )
436+ streams = []
437+ for fst in self .sets [1 :]:
438+ sb = StreamBuilder .from_slice (fst ._ptr , s )
439+ streams .append (sb )
440+ for sb in streams :
441+ opbuilder .push (sb ._ptr )
442+ return opbuilder .union ()
443+
444+ def __iter__ (self ):
445+ """ Get an iterator over all keys in all sets in lexicographical order.
446+ """
447+ if len (self .sets ) <= 1 :
448+ raise ValueError (
449+ "Must have more than one set to operate on." )
450+ opbuilder = OpBuilder (self .sets [0 ]._ptr ,
451+ input_type = OpBuilderInputType .SET )
452+ for fst in self .sets [1 :]:
453+ opbuilder .push (fst ._ptr )
454+ return opbuilder .union ()
0 commit comments