1111import argparse
1212import conversation as convo
1313import retrieval .wikipedia as wp
14- from transformers import AutoTokenizer , AutoModelForCausalLM , AutoConfig
14+ from transformers import AutoTokenizer , AutoModelForCausalLM , AutoConfig , StoppingCriteria , StoppingCriteriaList
1515from accelerate import infer_auto_device_map , init_empty_weights
1616
1717
class StopWordsCriteria(StoppingCriteria):
    """Stop generation once any stop word appears in the decoded output.

    Optionally streams newly decoded text to ``stream_callback`` as tokens
    arrive, holding back any trailing text that could be the beginning of a
    stop word so that stop words are never echoed to the stream.
    """

    def __init__(self, tokenizer, stop_words, stream_callback):
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._partial_result = ''
        # Text withheld from the stream because it may be a stop-word prefix.
        self._stream_buffer = ''
        self._stream_callback = stream_callback

    def _ends_with_stop_prefix(self):
        # True when the decoded text so far ends with a non-empty prefix of
        # any stop word, i.e. a stop word may still be in the middle of
        # being generated.
        for word in self._stop_words:
            for end in range(len(word), 0, -1):
                if self._partial_result.endswith(word[:end]):
                    return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        is_first_token = not self._partial_result
        # Decode only the newest token and append it to the running text.
        new_text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += new_text

        # Halt generation as soon as a complete stop word has appeared.
        if any(word in self._partial_result for word in self._stop_words):
            return True

        if self._stream_callback:
            if is_first_token:
                # Drop leading whitespace from the very first streamed chunk.
                new_text = new_text.lstrip()
            if self._ends_with_stop_prefix():
                # Might be the start of a stop word: withhold it for now.
                self._stream_buffer += new_text
            else:
                # Safe to emit everything withheld so far plus the new text.
                self._stream_callback(self._stream_buffer + new_text)
                self._stream_buffer = ''
        return False
44+
45+
1846class ChatModel :
1947 human_id = "<human>"
2048 bot_id = "<bot>"
@@ -54,7 +82,8 @@ def __init__(self, model_name, gpu_id, max_memory):
5482 )
5583 self ._tokenizer = AutoTokenizer .from_pretrained (model_name )
5684
57- def do_inference (self , prompt , max_new_tokens , do_sample , temperature , top_k ):
85+ def do_inference (self , prompt , max_new_tokens , do_sample , temperature , top_k , stream_callback = None ):
86+ stop_criteria = StopWordsCriteria (self ._tokenizer , [self .human_id ], stream_callback )
5887 inputs = (
5988 self ._tokenizer (prompt , return_tensors = 'pt' )
6089 .to (self ._model .device )
@@ -65,7 +94,8 @@ def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
6594 do_sample = do_sample ,
6695 temperature = temperature ,
6796 top_k = top_k ,
68- pad_token_id = self ._tokenizer .eos_token_id
97+ pad_token_id = self ._tokenizer .eos_token_id ,
98+ stopping_criteria = StoppingCriteriaList ([stop_criteria ]),
6999 )
70100 output = self ._tokenizer .batch_decode (outputs )[0 ]
71101
@@ -79,7 +109,7 @@ class OpenChatKitShell(cmd.Cmd):
79109 intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n "
80110 prompt = ">>> "
81111
82- def __init__ (self , gpu_id , model_name_or_path , max_tokens , sample , temperature , top_k , retrieval , max_memory ):
112+ def __init__ (self , gpu_id , model_name_or_path , max_tokens , sample , temperature , top_k , retrieval , max_memory , do_stream ):
83113 super ().__init__ ()
84114 self ._gpu_id = int (gpu_id )
85115 self ._model_name_or_path = model_name_or_path
@@ -89,6 +119,7 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
89119 self ._top_k = top_k
90120 self ._retrieval = retrieval
91121 self ._max_memory = max_memory
122+ self ._do_stream = do_stream
92123
93124 def preloop (self ):
94125 print (f"Loading { self ._model_name_or_path } to cuda:{ self ._gpu_id } ..." )
@@ -120,12 +151,13 @@ def do_say(self, arg):
120151 self ._max_tokens ,
121152 self ._sample ,
122153 self ._temperature ,
123- self ._top_k
154+ self ._top_k ,
155+ lambda x : print (x , end = '' , flush = True ) if self ._do_stream else None ,
124156 )
125157
126158 self ._convo .push_model_response (output )
127159
128- print (self ._convo .get_last_turn ())
160+ print ("" if self . _do_stream else self ._convo .get_last_turn ())
129161
130162 def do_raw_say (self , arg ):
131163 output = self ._model .do_inference (
@@ -183,6 +215,11 @@ def main():
183215 action = 'store_true' ,
184216 help = 'indicates whether to sample'
185217 )
218+ parser .add_argument (
219+ '--no-stream' ,
220+ action = 'store_true' ,
221+ help = 'indicates whether to stream tokens'
222+ )
186223 parser .add_argument (
187224 '--temperature' ,
188225 default = 0.6 ,
@@ -238,7 +275,8 @@ def main():
238275 args .temperature ,
239276 args .top_k ,
240277 args .retrieval ,
241- max_memory
278+ max_memory ,
279+ not args .no_stream ,
242280 ).cmdloop ()
243281
244282
0 commit comments