Commit 140af99

Authored by Dando18 (Daniel Nichols) and a co-author
Updates and Optimizations to Generation Scripts (#9)
* updates and optimizations to generation scripts

Co-authored-by: Daniel Nichols <dnicho@login01.chn>
1 parent 4c9276c · commit 140af99

4 files changed: 334 additions & 54 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 # python
 *.pyc
 __pycache__
+.env
 
 # cpp
 *.out

generate.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

generate/generate.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# std imports
import argparse
import json
import time
from tqdm import tqdm
import sys

# tpl imports
import torch
from transformers import pipeline

# local imports
from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config


""" Parse command line arguments """
parser = argparse.ArgumentParser(description='Generate code')
parser.add_argument('--prompts', required=True, help='Path to the prompt JSON file')
parser.add_argument('--model', required=True, help='Path to the language model')
parser.add_argument('--output', required=True, help='Path to the output JSON file')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='Maximum number of new tokens to generate (default: 1024)')
parser.add_argument('--num_samples_per_prompt', type=int, default=50, help='Number of code samples to generate (default: 50)')
parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for controlling randomness (default: 0.2)')
parser.add_argument('--top_p', type=float, default=0.95, help='Top p value for nucleus sampling (default: 0.95)')
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for generation (default: 16)')
parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
args = parser.parse_args()

""" Load prompts """
with open(args.prompts, 'r') as json_file:
    prompts = json.load(json_file)

""" Initialize inference config """
inference_config = get_inference_config(args.model, prompted=args.prompted)

# to use a torch.utils.data.Dataset with the HuggingFace pipeline, we need to flatten out the prompts
# and repeat them for however many samples we want to generate per prompt
prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

""" Initialize HuggingFace pipeline for generation """
generator = pipeline(model=args.model, torch_dtype=inference_config.get_dtype(), device_map="auto")
inference_config.init_padding(generator.tokenizer)

""" Create a prompt data set to pass to generate method """
prompt_dataset = PromptDataset([inference_config.format_prompt(p["prompt"]) for p in prompts_repeated])
generated_outputs = generator(
    prompt_dataset,
    max_new_tokens=args.max_new_tokens,
    do_sample=args.do_sample,
    temperature=args.temperature,
    top_p=args.top_p,
    pad_token_id=inference_config.get_pad_token_id(generator.tokenizer),
    eos_token_id=inference_config.get_eos_token_id(generator.tokenizer),
    batch_size=args.batch_size,
)

""" Iterate over prompts and generate code """
responses = []
cur_prompt = None
start_time = time.time()
total_tokens = 0
for idx, (prompt, output) in tqdm(enumerate(zip(prompts_repeated, generated_outputs)), total=len(prompts_repeated), desc="Generating code", file=sys.stdout):
    # start a fresh response object once per group of num_samples_per_prompt samples
    if idx % args.num_samples_per_prompt == 0:
        cur_prompt = prompt.copy()
        cur_prompt.update({"temperature": args.temperature, "top_p": args.top_p, "do_sample": args.do_sample, "max_new_tokens": args.max_new_tokens, "prompted": args.prompted})
        cur_prompt["outputs"] = []
    prompt_str = cur_prompt["prompt"]

    total_tokens += len(generator.tokenizer.encode(output[0]["generated_text"]))
    cleaned_output = clean_output(output[0]["generated_text"], prompt_str)
    cur_prompt["outputs"].append(cleaned_output)

    # the group for this prompt is complete; record it
    if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
        responses.append(cur_prompt)

    if idx != 0 and idx % args.num_samples_per_prompt == 0:
        print(f"Tokens per second: {total_tokens / (time.time() - start_time):.2f}")

end_time = time.time()
tokens_per_second = total_tokens / (end_time - start_time)
print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

""" Save responses to JSON file """
with open(args.output, 'w') as output_file:
    json.dump(responses, output_file, indent=4)
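
The script expects --prompts to point at a JSON list of objects, each carrying at least a "prompt" field; it writes the same objects back out with the sampling parameters recorded and an "outputs" list of cleaned completions attached. A minimal sketch of preparing an input file and running the script follows; the file names, the "name" field, and the flag values are illustrative assumptions, not part of this commit.

# make_prompts.py -- hypothetical driver, not in this commit
import json
import subprocess

# one entry per coding exercise; only the "prompt" key is required by generate.py
prompts = [
    {"name": "add", "prompt": "// return the sum of a and b\nint add(int a, int b) {"},
]
with open("prompts.json", "w") as f:
    json.dump(prompts, f, indent=4)

# run from inside the generate/ directory so that `from utils import ...` resolves
subprocess.run(
    ["python", "generate.py",
     "--prompts", "../prompts.json",
     "--model", "bigcode/starcoderbase",
     "--output", "../outputs.json",
     "--do_sample",
     "--num_samples_per_prompt", "4"],
    cwd="generate",
    check=True,
)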

generate/utils.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
# std imports
from abc import ABC, abstractmethod

# tpl imports
import torch
from torch.utils.data import Dataset
from transformers import StoppingCriteria


class InferenceConfig(ABC):

    def __init__(self, prompted : bool = False):
        self.prompted = prompted

    @abstractmethod
    def get_dtype(self):
        pass

    @abstractmethod
    def init_padding(self, tokenizer):
        pass

    @abstractmethod
    def get_pad_token_id(self, tokenizer) -> int:
        pass

    @abstractmethod
    def get_eos_token_id(self, tokenizer) -> int:
        pass

    @abstractmethod
    def format_prompt(self, prompt : str) -> str:
        pass


class StarCoderConfig(InferenceConfig):

    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return None

    def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"<filename>solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()


class CodeLlamaConfig(InferenceConfig):

    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()


class PolyCoderConfig(InferenceConfig):

    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()


class PhindConfig(InferenceConfig):

    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()


def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
    if model_name == "bigcode/starcoderbase":
        return StarCoderConfig(**kwargs)
    elif model_name.startswith("codellama/CodeLlama-") and 'Instruct' not in model_name:
        return CodeLlamaConfig(**kwargs)
    elif model_name == "NinedayWang/PolyCoder-2.7B":
        return PolyCoderConfig(**kwargs)
    elif model_name == 'Phind/Phind-CodeLlama-34B-v2':
        return PhindConfig(**kwargs)
    else:
        raise ValueError(f"Unknown model name: {model_name}")


def clean_output(output : str, prompt : str) -> str:
    """ Remove `prompt` from the beginning of `output`.
        Also truncate at the end of the function definition (i.e. the matching closing brace).
    """
    # remove everything up to the end of the first instance of prompt
    prompt_loc = output.find(prompt)
    if prompt_loc == -1:
        raise ValueError(f"Prompt not found in output: {prompt}")
    output = output[prompt_loc + len(prompt):].strip()

    # temporarily add an opening brace to the beginning
    output = '{' + output

    # find the brace matching output[0]
    stack = []
    index = 0
    while index < len(output):
        token = output[index]
        if token == '{':
            stack.append(token)
        elif token == '}':
            stack.pop()
            if len(stack) == 0:
                break
        index += 1

    # truncate at the matching brace
    output = output[1:index+1]
    return output


class PromptDataset(Dataset):
    ''' PyTorch dataset that simply wraps a list of strings. They do not have to have the same length.
    '''

    def __init__(self, prompts):
        super().__init__()
        self.prompts_ = prompts

    def __len__(self):
        return len(self.prompts_)

    def __getitem__(self, idx):
        return self.prompts_[idx]


def has_balanced_brackets(text : str, left_bracket : str = '{', right_bracket : str = '}') -> bool:
    ''' Check if a string has balanced brackets.
        modified from: https://stackoverflow.com/a/38834249/3769237

        Arguments:
            text: string to check for balanced brackets in.
            left_bracket: left bracket to balance
            right_bracket: right bracket to balance

        Returns:
            true if left_bracket and right_bracket are balanced
    '''
    stack = []
    balanced = True
    index = 0
    while index < len(text) and balanced:
        token = text[index]
        if token == left_bracket:
            stack.append(token)
        elif token == right_bracket:
            if len(stack) == 0:
                balanced = False
            else:
                stack.pop()
        index += 1

    return balanced and len(stack) == 0


class BalancedBracketsCriteria(StoppingCriteria):
    ''' Extension of transformers' text-generation stopping criteria.
        Stops either when the function is complete (i.e. { and } are balanced) or when max_length is
        surpassed, whichever happens first.

        _Note:_ This is a slow stopping criterion, but it's much faster than continually running
        model inference when it does not need to be run anymore.
    '''

    def __init__(self, max_length : int, tokenizer, left_bracket : str = '{', right_bracket : str = '}'):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.left_bracket = left_bracket
        self.right_bracket = right_bracket

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids.shape[-1] > self.max_length:
            # already too long, early stop
            return True

        # return true if {} are balanced i.e. the function is complete
        return all(
            has_balanced_brackets(
                self.tokenizer.decode(t),
                left_bracket=self.left_bracket,
                right_bracket=self.right_bracket
            ) for t in input_ids)
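
A quick sanity check of the two brace helpers (a sketch, not part of this commit; it assumes it is run from the generate/ directory so `utils` imports resolve, and the example strings are invented): clean_output drops the prompt and truncates the completion at the function's matching closing brace, and has_balanced_brackets confirms the result closes the brace opened in the prompt.

# demo of clean_output and has_balanced_brackets -- illustrative only
from utils import clean_output, has_balanced_brackets

prompt = "int add(int a, int b) {"
generated = prompt + "\n    return a + b;\n}\n\nint main() { return 0; }"

body = clean_output(generated, prompt)
print(repr(body))                            # 'return a + b;\n}' -- the trailing main() is cut off
print(has_balanced_brackets(prompt + body))  # True: the '{' opened in the prompt is closed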

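Note that BalancedBracketsCriteria is imported by generate/generate.py but not passed to the pipeline call in this commit. One plausible way to wire it in, sketched under the assumption that transformers' StoppingCriteriaList and the generate-time stopping_criteria argument are used (the model name and max length are placeholders):

# hypothetical wiring of BalancedBracketsCriteria -- not part of this commit
from transformers import AutoTokenizer, StoppingCriteriaList
from utils import BalancedBracketsCriteria

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase")
stopping_criteria = StoppingCriteriaList([
    BalancedBracketsCriteria(max_length=1024, tokenizer=tokenizer)
])

# generation would then stop as soon as every sequence in the batch has balanced braces,
# e.g. generator(prompt_dataset, ..., stopping_criteria=stopping_criteria)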