@@ -66,6 +66,52 @@ def find_matching_brace_index(code: str, open_brace_index: int) -> int:
6666 raise ValueError ("Unmatched opening brace" )
6767
6868
def clean_instruct_output(output: str, prompt: str, response_tag: str) -> str:
    """Extract the code solution from an instruct-style LLM completion.

    Everything up to and including the first ``response_tag`` is discarded.
    The remainder is searched for triple-backtick fenced code blocks, with or
    without a language tag such as ``c++`` or ``cpp``. If several blocks are
    present, the first one mentioning the target function name is preferred.
    The chosen block may hold the full function definition or only its body;
    both are handled.

    Args:
        output: Raw model output; must contain ``response_tag``.
        prompt: Prompt used to determine the target function name. The text
            after the last remaining triple-backtick fence (or the whole
            prompt when it has none) is passed to ``get_function_name``.
        response_tag: Marker after which the model's answer begins.

    Returns:
        If the selected block contains the function name followed by an
        opening brace: the text between that brace and its matching closing
        brace (or the end of the block when unmatched), plus a closing ``}``.
        Otherwise the selected block unchanged.

    Raises:
        ValueError: If ``response_tag`` does not occur in ``output``.
    """
    # Optional info string (language tag) that may follow an opening fence,
    # up to and including the end of that line.
    fence_info = r"[ \t]*[\w+#.\-]*[ \t]*\n"

    # 0. Drop everything up to the end of the first response tag.
    tag_loc = output.find(response_tag)
    if tag_loc == -1:
        raise ValueError(
            f"Response tag {response_tag!r} not found in output for prompt: {prompt}"
        )
    output = output[tag_loc + len(response_tag):].strip()

    # 1. Collect the contents of every complete fenced block. The opening
    #    fence may carry a language tag; the newline before the closing fence
    #    is not part of the captured code.
    code_blocks = re.findall(
        r"```" + fence_info + r"(.*?)\n?```", output, flags=re.DOTALL
    )

    # 2. Work out which function the prompt asks for.
    sub_prompt = (
        prompt.rstrip()
        .removesuffix(response_tag)
        .rstrip()
        .removesuffix("```")
        .split("```")[-1]
    )
    execution_model = "cuda" if "__global__" in sub_prompt else "serial"
    function_name = get_function_name(sub_prompt, execution_model)

    # 3. Prefer the first block that mentions the function; otherwise take the
    #    first block; otherwise fall back to the raw output.
    if code_blocks:
        matching_blocks = [block for block in code_blocks if function_name in block]
        selected_block = matching_blocks[0] if matching_blocks else code_blocks[0]
    else:
        fence_loc = output.find("```")
        if fence_loc == -1:
            selected_block = output
        else:
            # A fence was opened but never closed (e.g. generation was cut
            # off): keep what follows it, minus the language-tag line.
            selected_block = output[fence_loc + len("```"):]
            info_match = re.match(fence_info, selected_block)
            if info_match is not None:
                selected_block = selected_block[info_match.end():]

    # 4. Body-only answers are returned as they are.
    if function_name not in selected_block:
        return selected_block

    function_start_index = selected_block.index(function_name)
    open_brace_index = selected_block.find("{", function_start_index)
    if open_brace_index == -1:
        # The name is mentioned (e.g. a call) but no definition body follows,
        # so there is nothing to unwrap.
        return selected_block

    try:
        close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
    except ValueError:
        # Unbalanced braces: take everything to the end of the block.
        close_brace_index = len(selected_block)

    function_body = selected_block[open_brace_index + 1:close_brace_index]
    return function_body + "}"
113+
114+
69115class InferenceConfig (ABC ):
70116
71117 def __init__ (self , prompted : bool = False ):
@@ -290,50 +336,70 @@ def format_prompt(self, prompt : str) -> str:
290336 return prompt .strip ()
291337
def clean_output(self, output: str, prompt: str) -> str:
    """Extract the code solution from the raw model output.

    Delegates to ``clean_instruct_output`` with ``"@@ Response"`` as the
    marker after which the model's answer begins.
    """
    response_marker = "@@ Response"
    return clean_instruct_output(output, prompt, response_marker)
class DeepSeekBaseConfig(InferenceConfig):
    """Inference configuration used for the DeepSeek Coder base models."""

    def __init__(self, prompted: bool = False):
        """Forward the ``prompted`` flag to the base configuration."""
        super().__init__(prompted=prompted)

    def get_dtype(self):
        """Return the torch dtype to use: ``torch.bfloat16``."""
        return torch.bfloat16

    def init_padding(self, tokenizer):
        """Configure ``tokenizer`` padding in place."""
        # Reuse the EOS token as the pad token so inputs can be batched.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # Pad on the left, as needed for decoder-only models.
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        """Return the tokenizer's pad token id."""
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        """Return the tokenizer's end-of-sequence token id."""
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        """Return ``False``."""
        return False

    def format_prompt(self, prompt: str) -> str:
        """Return the text to feed to the model for ``prompt``.

        When ``self.prompted`` is false the prompt is only stripped of
        surrounding whitespace. Otherwise it is prefixed with a two-line
        C++ comment header and returned without stripping.
        """
        if not self.prompted:
            return prompt.strip()
        header = (
            "// filename: solutions/solution_1.cpp\n"
            " // here is the correct implementation of the coding exercise\n"
            " \n "
        )
        return f"{header}{prompt} "

    def clean_output(self, output: str, prompt: str) -> str:
        """Clean ``output`` with the module-level ``clean_output`` helper."""
        cleaned = clean_output(output, prompt)
        return cleaned
370+
371+
class InstructConfig(InferenceConfig):
    """Inference configuration for models prompted with instruction/response tags."""

    def __init__(self, prompted: bool = False, instruction_tag: str = "### Instruction", response_tag: str = "### Response"):
        """Store the two tags used to build prompts and to locate the response."""
        super().__init__(prompted=prompted)
        self.response_tag = response_tag
        self.instruction_tag = instruction_tag

    def get_dtype(self):
        """Return the torch dtype to use: ``torch.bfloat16``."""
        return torch.bfloat16

    def init_padding(self, tokenizer):
        """Configure ``tokenizer`` padding in place."""
        # Reuse the EOS token as the pad token so inputs can be batched.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # Pad on the left, as needed for decoder-only models.
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        """Return the tokenizer's pad token id."""
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        """Return the tokenizer's end-of-sequence token id."""
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        """Return ``False``."""
        return False

    def format_prompt(self, prompt: str) -> str:
        """Wrap ``prompt`` in an instruction asking for a single C++ function.

        The function name is obtained from ``get_function_name``. The
        instruction is placed between ``self.instruction_tag`` and
        ``self.response_tag``, and the result is stripped of surrounding
        whitespace.
        """
        kind = "cuda" if "__global__" in prompt else "serial"
        function_name = get_function_name(prompt, kind)
        instruction = (
            "Complete the following c++ function.\n "
            f"```c++{prompt.strip()} ```\n "
            f"Write only the function {function_name} and no other code. "
            "Enclose your solution in ```c++ and ```."
        )
        separator = " \n "
        sections = (self.instruction_tag, instruction, self.response_tag)
        full_prompt = separator.join(sections) + separator
        return full_prompt.strip()

    def clean_output(self, output: str, prompt: str) -> str:
        """Extract the code solution that follows ``self.response_tag``."""
        tag = self.response_tag
        return clean_instruct_output(output, prompt, tag)
337403
338404def get_inference_config (model_name : str , ** kwargs ) -> InferenceConfig :
339405 if model_name == "bigcode/starcoderbase" :
@@ -350,6 +416,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
350416 return ReplitConfig (** kwargs )
351417 elif model_name .startswith ('ise-uiuc/Magicoder' ):
352418 return MagicoderConfig (** kwargs )
419+ elif model_name in ['deepseek-ai/deepseek-coder-6.7b-base' , 'deepseek-ai/deepseek-coder-7b-base-v1.5' ]:
420+ return DeepSeekBaseConfig (** kwargs )
421+ elif model_name .startswith ('hpcgroup/hpc-coder-v2' ):
422+ return InstructConfig (instruction_tag = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n \n ### Instruction:' , response_tag = '### Response:' , ** kwargs )
423+ elif model_name .startswith ('hpcgroup/rlpf' ):
424+ return InstructConfig (instruction_tag = '### Instruction' , response_tag = '### Response' , ** kwargs )
353425 else :
354426 raise ValueError (f"Unknown model name: { model_name } " )
355427
0 commit comments