import argparse
import conversation as convo
import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
from accelerate import infer_auto_device_map, init_empty_weights


+class StopWordsCriteria(StoppingCriteria):
+    def __init__(self, tokenizer, stop_words, stream_callback):
+        self._tokenizer = tokenizer
+        self._stop_words = stop_words
+        self._partial_result = ''
+        self._stream_buffer = ''
+        self._stream_callback = stream_callback
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        first = not self._partial_result
+        text = self._tokenizer.decode(input_ids[0, -1])
+        self._partial_result += text
+        for stop_word in self._stop_words:
+            if stop_word in self._partial_result:
+                return True
+        if self._stream_callback:
+            if first:
+                text = text.lstrip()
+            # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
+            for stop_word in self._stop_words:
+                for i in range(1, len(stop_word)):
+                    if self._partial_result.endswith(stop_word[0:i]):
+                        self._stream_buffer += text
+                        return False
+            self._stream_callback(self._stream_buffer + text)
+            self._stream_buffer = ''
+        return False
+
+
class ChatModel:
    human_id = "<human>"
    bot_id = "<bot>"
@@ -54,7 +83,8 @@ def __init__(self, model_name, gpu_id, max_memory):
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)

-    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
+    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
+        stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
        inputs = (
            self._tokenizer(prompt, return_tensors='pt')
            .to(self._model.device)
@@ -65,7 +95,8 @@ def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
-            pad_token_id=self._tokenizer.eos_token_id
+            pad_token_id=self._tokenizer.eos_token_id,
+            stopping_criteria=StoppingCriteriaList([stop_criteria]),
        )
        output = self._tokenizer.batch_decode(outputs)[0]

@@ -79,7 +110,7 @@ class OpenChatKitShell(cmd.Cmd):
    intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
    prompt = ">>> "

-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
        super().__init__()
        self._gpu_id = int(gpu_id)
        self._model_name_or_path = model_name_or_path
@@ -89,6 +120,7 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
        self._top_k = top_k
        self._retrieval = retrieval
        self._max_memory = max_memory
+        self._do_stream = do_stream

    def preloop(self):
        print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
@@ -120,12 +152,13 @@ def do_say(self, arg):
            self._max_tokens,
            self._sample,
            self._temperature,
-            self._top_k
+            self._top_k,
+            lambda x: print(x, end='', flush=True) if self._do_stream else None,
        )

        self._convo.push_model_response(output)

-        print(self._convo.get_last_turn())
+        print("" if self._do_stream else self._convo.get_last_turn())

    def do_raw_say(self, arg):
        output = self._model.do_inference(
@@ -183,6 +216,11 @@ def main():
        action='store_true',
        help='indicates whether to sample'
    )
+    parser.add_argument(
+        '--no-stream',
+        action='store_true',
+        help='disable streaming of tokens as they are generated'
+    )
    parser.add_argument(
        '--temperature',
        default=0.6,
@@ -238,7 +276,8 @@ def main():
        args.temperature,
        args.top_k,
        args.retrieval,
-        max_memory
+        max_memory,
+        not args.no_stream,
    ).cmdloop()
244283
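A minimal sketch for trying the new stopping criteria outside the shell (not part of this commit). It assumes this module is importable as `bot` and substitutes `gpt2` as a stand-in checkpoint; the real shell passes the model named on the command line and supplies the print callback only when `--no-stream` is absent:

```python
# Hypothetical usage sketch: exercise StopWordsCriteria directly.
# "gpt2" and the `bot` import path are stand-ins for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from bot import StopWordsCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stream each decoded fragment to stdout; generation halts once "<human>"
# appears in the accumulated output, and partial matches such as "<hu" are
# buffered rather than streamed.
criteria = StopWordsCriteria(
    tokenizer,
    ["<human>"],
    lambda fragment: print(fragment, end='', flush=True),
)

inputs = tokenizer("<human>: Hello!\n<bot>:", return_tensors='pt')
model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.6,
    top_k=40,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=StoppingCriteriaList([criteria]),
)
```

This mirrors what `do_inference` now does internally: build a `StopWordsCriteria` over `human_id` and hand it to `generate` via `stopping_criteria`.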