
Commit a71963d

Merge pull request #75 from togethercomputer/add-streaming-tokens
Add streaming tokens
2 parents aa09ce7 + aff0827

1 file changed: inference/bot.py (46 additions, 7 deletions)
@@ -11,10 +11,39 @@
 import argparse
 import conversation as convo
 import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
 from accelerate import infer_auto_device_map, init_empty_weights
 
 
+class StopWordsCriteria(StoppingCriteria):
+    def __init__(self, tokenizer, stop_words, stream_callback):
+        self._tokenizer = tokenizer
+        self._stop_words = stop_words
+        self._partial_result = ''
+        self._stream_buffer = ''
+        self._stream_callback = stream_callback
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        first = not self._partial_result
+        text = self._tokenizer.decode(input_ids[0, -1])
+        self._partial_result += text
+        for stop_word in self._stop_words:
+            if stop_word in self._partial_result:
+                return True
+        if self._stream_callback:
+            if first:
+                text = text.lstrip()
+            # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
+            for stop_word in self._stop_words:
+                for i in range(1, len(stop_word)):
+                    if self._partial_result.endswith(stop_word[0:i]):
+                        self._stream_buffer += text
+                        return False
+            self._stream_callback(self._stream_buffer + text)
+            self._stream_buffer = ''
+        return False
+
+
 class ChatModel:
     human_id = "<human>"
     bot_id = "<bot>"
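The buffering logic above keeps a stop word from being partially streamed when it arrives split across several tokens (e.g. "<", "hu", "man", ">"): whenever the accumulated text ends with a proper prefix of a stop word, the newest token is parked in _stream_buffer rather than flushed to the callback, and the buffer is emitted with the next token once no stop word can still be forming. A standalone toy of that rule (should_buffer is an illustrative helper, not from the repo):

def should_buffer(partial_result, stop_words):
    # True when partial_result ends with a proper prefix of some stop word,
    # meaning the next tokens might still complete it.
    for stop_word in stop_words:
        for i in range(1, len(stop_word)):
            if partial_result.endswith(stop_word[:i]):
                return True
    return False

print(should_buffer("Hello <hu", ["<human>"]))    # True: "<hu" could grow into "<human>"
print(should_buffer("Hello there", ["<human>"]))  # False: safe to flush to the callback

Note that the __call__ annotations reference torch, so bot.py relies on torch being imported earlier in the file, outside this hunk.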
@@ -54,7 +83,8 @@ def __init__(self, model_name, gpu_id, max_memory):
         )
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
+    def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
+        stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
         inputs = (
             self._tokenizer(prompt, return_tensors='pt')
             .to(self._model.device)
@@ -65,7 +95,8 @@ def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k):
             do_sample=do_sample,
             temperature=temperature,
             top_k=top_k,
-            pad_token_id=self._tokenizer.eos_token_id
+            pad_token_id=self._tokenizer.eos_token_id,
+            stopping_criteria=StoppingCriteriaList([stop_criteria]),
         )
         output = self._tokenizer.batch_decode(outputs)[0]
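Hugging Face's generate invokes each StoppingCriteria after every newly sampled token, passing the full sequence so far, which is what lets StopWordsCriteria double as a streaming hook. A minimal end-to-end sketch, assuming the StopWordsCriteria class from this diff is in scope; "gpt2" is an illustrative stand-in, not the OpenChatKit model:

import torch  # required by StopWordsCriteria's type annotations
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

criteria = StopWordsCriteria(
    tokenizer,
    ["<human>"],                             # stop once the model starts a new human turn
    lambda x: print(x, end='', flush=True),  # stream each decoded token to stdout
)
inputs = tokenizer("<human>: Hello!\n<bot>:", return_tensors='pt')
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    pad_token_id=tokenizer.eos_token_id,
    stopping_criteria=StoppingCriteriaList([criteria]),
)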

@@ -79,7 +110,7 @@ class OpenChatKitShell(cmd.Cmd):
     intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
     prompt = ">>> "
 
-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
         super().__init__()
         self._gpu_id = int(gpu_id)
         self._model_name_or_path = model_name_or_path
@@ -89,6 +120,7 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
         self._top_k = top_k
         self._retrieval = retrieval
         self._max_memory = max_memory
+        self._do_stream = do_stream
 
     def preloop(self):
         print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
@@ -120,12 +152,13 @@ def do_say(self, arg):
             self._max_tokens,
             self._sample,
             self._temperature,
-            self._top_k
+            self._top_k,
+            lambda x : print(x, end='', flush=True) if self._do_stream else None,
         )
 
         self._convo.push_model_response(output)
 
-        print(self._convo.get_last_turn())
+        print("" if self._do_stream else self._convo.get_last_turn())
 
     def do_raw_say(self, arg):
         output = self._model.do_inference(
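A subtlety in the callback wiring: Python's conditional expression binds inside the lambda body, so do_say always passes a callable, which simply prints nothing when streaming is off. A sketch of the parse (do_stream stands in for self._do_stream):

do_stream = False

# The diff's expression parses with the conditional inside the lambda body...
stream = lambda x: (print(x, end='', flush=True) if do_stream else None)

# ...not as a conditional choosing between a lambda and None:
stream = (lambda x: print(x, end='', flush=True)) if do_stream else None

Because a callable is always handed to StopWordsCriteria, its streaming branch always runs; the final print("") in do_say then only adds the trailing newline after streamed output.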
@@ -183,6 +216,11 @@ def main():
         action='store_true',
         help='indicates whether to sample'
     )
+    parser.add_argument(
+        '--no-stream',
+        action='store_true',
+        help='indicates whether to stream tokens'
+    )
     parser.add_argument(
         '--temperature',
         default=0.6,
@@ -238,7 +276,8 @@ def main():
         args.temperature,
         args.top_k,
         args.retrieval,
-        max_memory
+        max_memory,
+        not args.no_stream,
     ).cmdloop()
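Streaming is enabled by default: the flag is a double negative, so passing --no-stream sets args.no_stream to True, which main() inverts before handing it to the shell as do_stream. A minimal sketch of that wiring, using the names from the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--no-stream',
    action='store_true',
    help='indicates whether to stream tokens'
)
args = parser.parse_args([])        # no flags given: streaming stays enabled
do_stream = not args.no_stream      # True by default, False with --no-stream
print(do_stream)                    # prints: True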
