11# std imports
22from abc import ABC , abstractmethod
3+ import re
34
45# tpl imports
56import torch
67from torch .utils .data import Dataset
78from transformers import StoppingCriteria
89
910
11+ def clean_output (output : str , prompt : str ) -> str :
12+ """ Remove `prompt` from the begging of `output`.
13+ Also truncate at the end of the function definition (i.e. matching closing brace).
14+ """
15+ # replace up to the end of the first instance of prompt
16+ prompt_loc = output .find (prompt )
17+ if prompt_loc == - 1 :
18+ raise ValueError (f"Prompt not found in output: { prompt } " )
19+ output = output [prompt_loc + len (prompt ):].strip ()
20+
21+ # temporarily add opening brace to the beginning
22+ output = '{' + output
23+
24+ # find the matching brace to output[0]
25+ stack = []
26+ index = 0
27+ while index < len (output ):
28+ token = output [index ]
29+ if token == '{' :
30+ stack .append (token )
31+ elif token == '}' :
32+ stack .pop ()
33+ if len (stack ) == 0 :
34+ break
35+
36+ index += 1
37+
38+ # truncate at the matching brace
39+ output = output [1 :index + 1 ]
40+ return output
41+
42+ GPU_FUNCTION_NAME_PATTERN = re .compile (r"__global__ void ([a-zA-Z0-9_]+)\(" )
43+ CPU_FUNCTION_NAME_PATTERN = re .compile (r"\s*[a-zA-Z_]+ ([a-zA-Z0-9_]+)\(" )
44+ def get_function_name (prompt : str , execution_model : str ) -> str :
45+ if execution_model in ['cuda' , 'hip' ]:
46+ match = GPU_FUNCTION_NAME_PATTERN .match (prompt .splitlines ()[- 1 ])
47+ else :
48+ match = CPU_FUNCTION_NAME_PATTERN .match (prompt .splitlines ()[- 1 ])
49+ if match is None :
50+ raise ValueError (f"Could not find function name in prompt: { prompt } " )
51+ return match .group (1 )
52+
53+
54+ def find_matching_brace_index (code : str , open_brace_index : int ) -> int :
55+ """Finds the index of the closing brace that matches the opening brace at the given index."""
56+
57+ brace_count = 1
58+ for i in range (open_brace_index + 1 , len (code )):
59+ if code [i ] == "{" :
60+ brace_count += 1
61+ elif code [i ] == "}" :
62+ brace_count -= 1
63+ if brace_count == 0 :
64+ return i
65+
66+ raise ValueError ("Unmatched opening brace" )
67+
68+
1069class InferenceConfig (ABC ):
1170
1271 def __init__ (self , prompted : bool = False ):
@@ -36,6 +95,10 @@ def trust_remote_code(self) -> bool:
3695 def format_prompt (self , prompt : str ) -> str :
3796 pass
3897
98+ @abstractmethod
99+ def clean_output (self , output : str , prompt : str ) -> str :
100+ pass
101+
39102
40103class StarCoderConfig (InferenceConfig ):
41104
@@ -63,6 +126,9 @@ def format_prompt(self, prompt : str) -> str:
63126 return f"<filename>solutions/solution_1.cpp\n // here is the correct implementation of the coding exercise\n \n { prompt } "
64127 return prompt .strip ()
65128
129+ def clean_output (self , output : str , prompt : str ) -> str :
130+ return clean_output (output , prompt )
131+
66132class CodeLlamaConfig (InferenceConfig ):
67133
68134 def __init__ (self , prompted : bool = False ):
@@ -90,6 +156,8 @@ def format_prompt(self, prompt : str) -> str:
90156 return f"// filename: solutions/solution_1.cpp\n // here is the correct implementation of the coding exercise\n \n { prompt } "
91157 return prompt .strip ()
92158
159+ def clean_output (self , output : str , prompt : str ) -> str :
160+ return clean_output (output , prompt )
93161
94162class PolyCoderConfig (InferenceConfig ):
95163
@@ -116,6 +184,9 @@ def format_prompt(self, prompt : str) -> str:
116184 if self .prompted :
117185 return f"// filename: solutions/solution_1.cpp\n // here is the correct implementation of the coding exercise\n \n { prompt } "
118186 return prompt .strip ()
187+
188+ def clean_output (self , output : str , prompt : str ) -> str :
189+ return clean_output (output , prompt )
119190
120191
121192class PhindConfig (InferenceConfig ):
@@ -144,6 +215,9 @@ def format_prompt(self, prompt : str) -> str:
144215 return f"// filename: solutions/solution_1.cpp\n // here is the correct implementation of the coding exercise\n \n { prompt } "
145216 return prompt .strip ()
146217
218+ def clean_output (self , output : str , prompt : str ) -> str :
219+ return clean_output (output , prompt )
220+
147221
148222class ReplitConfig (InferenceConfig ):
149223
@@ -174,6 +248,92 @@ def format_prompt(self, prompt : str) -> str:
174248 return f"// filename: solutions/solution_1.cpp\n // here is the correct implementation of the coding exercise\n \n { prompt } "
175249 return prompt .strip ()
176250
251+ def clean_output (self , output : str , prompt : str ) -> str :
252+ return clean_output (output , prompt )
253+
254+
255+ class MagicoderConfig (InferenceConfig ):
256+
257+ PROMPT_TEMPLATE = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
258+
259+ @@ Instruction
260+ {instruction}
261+
262+ @@ Response
263+ """
264+
265+ def __init__ (self , prompted : bool = False ):
266+ super ().__init__ (prompted = prompted )
267+
268+ def get_dtype (self ):
269+ return torch .bfloat16
270+
271+ def init_padding (self , tokenizer ):
272+ tokenizer .pad_token_id = tokenizer .eos_token_id # for batching
273+ tokenizer .padding_side = "left" # for decoder-only models
274+ pass
275+
276+ def get_pad_token_id (self , tokenizer ) -> int :
277+ return tokenizer .pad_token_id
278+
279+ def get_eos_token_id (self , tokenizer ) -> int :
280+ return tokenizer .eos_token_id
281+
282+ def trust_remote_code (self ) -> bool :
283+ return False
284+
285+ def format_prompt (self , prompt : str ) -> str :
286+ if self .prompted :
287+ function_name = get_function_name (prompt , "cuda" if "__global__" in prompt else "serial" )
288+ prompt = f"Complete the following c++ function.\n ```c++{ prompt .strip ()} ```\n Write only the function { function_name } and no other code. Enclose your solution in ```c++ and ```."
289+ return self .PROMPT_TEMPLATE .format (instruction = prompt )
290+ return prompt .strip ()
291+
292+ def clean_output (self , output : str , prompt : str ) -> str :
293+ """ Clean LLM output to find code solution. The output should be in a ```c++ ``` code block. If there are
294+ multiple, then it tries to find the block with the function definition (as contained in the prompt).
295+ The code block itself may include the function definition and body OR just the body. This will try
296+ to parse both.
297+ """
298+ # 0. replace up to the end of the first instance of prompt
299+ prompt_loc = output .find ("@@ Response" )
300+ if prompt_loc == - 1 :
301+ raise ValueError (f"Prompt not found in output: { prompt } " )
302+ output = output [prompt_loc + len ("@@ Response" ):].strip ()
303+
304+ # 1. Find all code blocks enclosed in triple backticks with "c++" language tag
305+ code_blocks = re .findall (r"```c\+\+\n(.*?)\n```" , output , flags = re .DOTALL )
306+ code_blocks = [block .lstrip ('```c++' ).rstrip ('```' ) for block in code_blocks ]
307+
308+ # 2. Prioritize code blocks containing the function definition from the prompt
309+ sub_prompt = prompt .rstrip ().removesuffix ("@@ Response" ).rstrip ().removesuffix ("```" ).split ("```" )[- 1 ]
310+ function_name = get_function_name (sub_prompt , "cuda" if "__global__" in sub_prompt else "serial" )
311+ prioritized_blocks = [block for block in code_blocks if function_name in block ]
312+
313+ # 3. Choose the first block if multiple match, or any block if none match
314+ if len (code_blocks ) > 0 :
315+ selected_block = prioritized_blocks [0 ] if prioritized_blocks else code_blocks [0 ]
316+ else :
317+ if '```c++' in output : # starts with ```c++ but it didn't finish
318+ code_idx = output .find ('```c++' )
319+ selected_block = output [code_idx :].removeprefix ('```c++' )
320+ else :
321+ selected_block = output
322+
323+ # 4. Handle cases where the block contains only the function body
324+ if function_name not in selected_block :
325+ return selected_block
326+ else :
327+ function_start_index = selected_block .index (function_name )
328+ open_brace_index = selected_block .find ("{" , function_start_index )
329+ try :
330+ close_brace_index = find_matching_brace_index (selected_block , open_brace_index )
331+ except ValueError :
332+ close_brace_index = len (selected_block )
333+
334+ function_body = selected_block [open_brace_index + 1 : close_brace_index ]
335+ return function_body + "}"
336+
177337
178338def get_inference_config (model_name : str , ** kwargs ) -> InferenceConfig :
179339 if model_name == "bigcode/starcoderbase" :
@@ -186,41 +346,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
186346 return PhindConfig (** kwargs )
187347 elif model_name == 'replit/replit-code-v1_5-3b' :
188348 return ReplitConfig (** kwargs )
349+ elif model_name .startswith ('ise-uiuc/Magicoder' ):
350+ return MagicoderConfig (** kwargs )
189351 else :
190352 raise ValueError (f"Unknown model name: { model_name } " )
191353
192354
193- def clean_output (output : str , prompt : str ) -> str :
194- """ Remove `prompt` from the begging of `output`.
195- Also truncate at the end of the function definition (i.e. matching closing brace).
196- """
197- # replace up to the end of the first instance of prompt
198- prompt_loc = output .find (prompt )
199- if prompt_loc == - 1 :
200- raise ValueError (f"Prompt not found in output: { prompt } " )
201- output = output [prompt_loc + len (prompt ):].strip ()
202-
203- # temporarily add opening brace to the beginning
204- output = '{' + output
205-
206- # find the matching brace to output[0]
207- stack = []
208- index = 0
209- while index < len (output ):
210- token = output [index ]
211- if token == '{' :
212- stack .append (token )
213- elif token == '}' :
214- stack .pop ()
215- if len (stack ) == 0 :
216- break
217-
218- index += 1
219-
220- # truncate at the matching brace
221- output = output [1 :index + 1 ]
222- return output
223-
224355class PromptDataset (Dataset ):
225356 ''' PyTorch dataset that simply wraps a list of strings. They do not have to have the same length.
226357 '''
0 commit comments