Skip to content

Commit ce04856

Browse files
author
Justin Foutts
committed
Add streaming tokens
1 parent aa09ce7 commit ce04856

1 file changed

Lines changed: 45 additions & 7 deletions

File tree

inference/bot.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,38 @@
1111
import argparse
1212
import conversation as convo
1313
import retrieval.wikipedia as wp
14-
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
14+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
1515
from accelerate import infer_auto_device_map, init_empty_weights
1616

1717

18+
class StopWordsCriteria(StoppingCriteria):
    """Generation stopping criterion that halts when any stop word appears.

    Decodes each newly generated token, accumulates the decoded text, and
    signals generation to stop as soon as one of ``stop_words`` occurs in
    the accumulated output. If ``stream_callback`` is provided, decoded
    text is forwarded to it incrementally; text that could be the start of
    a stop word is held back in a buffer until it is known not to be one,
    so a stop word is never streamed to the caller.
    """

    def __init__(self, tokenizer, stop_words, stream_callback):
        # Tokenizer used to turn generated token ids back into text.
        self._tokenizer = tokenizer
        # Words whose appearance ends generation (e.g. the human turn marker).
        self._stop_words = stop_words
        # Everything decoded so far for this generation.
        self._partial_result = ''
        # Text withheld from the callback because it may begin a stop word.
        self._stream_buffer = ''
        # Optional callable receiving streamed text chunks; may be falsy.
        self._stream_callback = stream_callback

    def _ends_with_stop_prefix(self):
        """Return True if the accumulated text ends with any non-empty
        prefix of any stop word (i.e. a stop word may be forming)."""
        return any(
            self._partial_result.endswith(word[:length])
            for word in self._stop_words
            for length in range(len(word), 0, -1)
        )

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        is_first_token = not self._partial_result
        # NOTE(review): decoding one token at a time assumes each token
        # decodes independently — confirm for the tokenizer in use.
        decoded = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += decoded

        # Stop as soon as any stop word has fully appeared. Any text still
        # buffered is deliberately discarded — it was part of the stop word.
        if any(word in self._partial_result for word in self._stop_words):
            return True

        if self._stream_callback:
            if is_first_token:
                # Drop leading whitespace from the very first streamed chunk.
                decoded = decoded.lstrip()
            if self._ends_with_stop_prefix():
                # Might be the start of a stop word: hold it back for now.
                self._stream_buffer += decoded
            else:
                # Safe to emit: flush anything held back plus the new text.
                self._stream_callback(self._stream_buffer + decoded)
                self._stream_buffer = ''
        return False

1846
class ChatModel:
1947
human_id = "<human>"
2048
bot_id = "<bot>"
@@ -54,7 +82,8 @@ def __init__(self, model_name, gpu_id, max_memory):
5482
)
5583
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
5684

57-
def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
85+
def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
86+
stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
5887
inputs = (
5988
self._tokenizer(prompt, return_tensors='pt')
6089
.to(self._model.device)
@@ -65,7 +94,8 @@ def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
6594
do_sample=do_sample,
6695
temperature=temperature,
6796
top_k=top_k,
68-
pad_token_id=self._tokenizer.eos_token_id
97+
pad_token_id=self._tokenizer.eos_token_id,
98+
stopping_criteria=StoppingCriteriaList([stop_criteria]),
6999
)
70100
output = self._tokenizer.batch_decode(outputs)[0]
71101

@@ -79,7 +109,7 @@ class OpenChatKitShell(cmd.Cmd):
79109
intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
80110
prompt = ">>> "
81111

82-
def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory):
112+
def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
83113
super().__init__()
84114
self._gpu_id = int(gpu_id)
85115
self._model_name_or_path = model_name_or_path
@@ -89,6 +119,7 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
89119
self._top_k = top_k
90120
self._retrieval = retrieval
91121
self._max_memory = max_memory
122+
self._do_stream = do_stream
92123

93124
def preloop(self):
94125
print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
@@ -120,12 +151,13 @@ def do_say(self, arg):
120151
self._max_tokens,
121152
self._sample,
122153
self._temperature,
123-
self._top_k
154+
self._top_k,
155+
lambda x : print(x, end='', flush=True) if self._do_stream else None,
124156
)
125157

126158
self._convo.push_model_response(output)
127159

128-
print(self._convo.get_last_turn())
160+
print("" if self._do_stream else self._convo.get_last_turn())
129161

130162
def do_raw_say(self, arg):
131163
output = self._model.do_inference(
@@ -183,6 +215,11 @@ def main():
183215
action='store_true',
184216
help='indicates whether to sample'
185217
)
218+
parser.add_argument(
219+
'--no-stream',
220+
action='store_true',
221+
help='indicates whether to stream tokens'
222+
)
186223
parser.add_argument(
187224
'--temperature',
188225
default=0.6,
@@ -238,7 +275,8 @@ def main():
238275
args.temperature,
239276
args.top_k,
240277
args.retrieval,
241-
max_memory
278+
max_memory,
279+
not args.no_stream,
242280
).cmdloop()
243281

244282

0 commit comments

Comments (0)