Skip to content

Commit a542ccf

Browse files
authored
handle hpc-coder-v2 in generation scripts (#28)
1 parent 5e0479f commit a542ccf

2 files changed

Lines changed: 115 additions & 41 deletions

File tree

generate/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@
130130
cur_prompt = prompt.copy()
131131
cur_prompt.update({"temperature": args.temperature, "top_p": args.top_p, "do_sample": args.do_sample, "max_new_tokens": args.max_new_tokens, "prompted": args.prompted})
132132
cur_prompt["outputs"] = []
133+
cur_prompt["raw_outputs"] = []
133134
prompt_str = cur_prompt["prompt"]
134135

135136
total_tokens += len(generator.tokenizer.encode(output[0]["generated_text"]))
136137
cleaned_output = inference_config.clean_output(output[0]["generated_text"], prompt_str)
137138
cur_prompt["outputs"].append(cleaned_output)
139+
cur_prompt["raw_outputs"].append(output[0]["generated_text"])
138140

139141
if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
140142
responses.append(cur_prompt)

generate/utils.py

Lines changed: 113 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,52 @@ def find_matching_brace_index(code: str, open_brace_index: int) -> int:
6666
raise ValueError("Unmatched opening brace")
6767

6868

69+
def clean_instruct_output(output: str, prompt: str, response_tag: str) -> str:
    """ Clean instruct-model output to find the code solution.

    The solution is expected inside a fenced code block (```c++, ```cpp or an
    untagged ``` fence). If there are multiple blocks, the first one that
    mentions the function named in the prompt is preferred. The block may
    contain the full function definition OR just the body; both are handled.

    Args:
        output: Raw generated text; must contain ``response_tag``.
        prompt: The formatted prompt that was given to the model. Its last
            fenced section is used to recover the target function name.
        response_tag: Marker separating the instruction from the response
            (e.g. "### Response").

    Returns:
        The function body followed by a closing "}" when the selected block
        contains the function definition, otherwise the selected block as-is.

    Raises:
        ValueError: if ``response_tag`` does not occur in ``output``.
    """
    cpp_tags = ("", "c++", "cpp")

    # 0. drop everything up to the end of the first instance of the response tag
    tag_loc = output.find(response_tag)
    if tag_loc == -1:
        raise ValueError(f"Response tag {response_tag!r} not found in output for prompt: {prompt}")
    output = output[tag_loc + len(response_tag):].strip()

    # 1. Find all fenced code blocks. The opening fence may carry any language
    #    tag (matched so that opening/closing fences pair up correctly); the tag
    #    cannot contain a backtick, so inline "```c++ and ```" text is ignored.
    fenced = re.findall(r"```([^\n`]*)\n(.*?)\n?```", output, flags=re.DOTALL)
    cpp_blocks = [body for tag, body in fenced if tag.strip().lower() in cpp_tags]
    # prefer C++/untagged blocks, but fall back to any fenced block
    code_blocks = cpp_blocks if cpp_blocks else [body for _, body in fenced]

    # 2. Recover the function name from the code section of the prompt
    sub_prompt = prompt.rstrip().removesuffix(response_tag).rstrip().removesuffix("```").split("```")[-1]
    function_name = get_function_name(sub_prompt, "cuda" if "__global__" in sub_prompt else "serial")

    # 3. Choose the first block mentioning the function, or the first block otherwise
    if code_blocks:
        selected_block = next(
            (block for block in code_blocks if function_name in block),
            code_blocks[0],
        )
    elif '```' in output:
        # a fence was opened but never closed (e.g. generation was truncated)
        remainder = output[output.find('```') + len('```'):]
        first_line, _, rest = remainder.partition("\n")
        # drop the language-tag line so it does not end up in the code
        selected_block = rest if first_line.strip().lower() in cpp_tags else remainder
    else:
        selected_block = output

    # 4. Handle cases where the block contains only the function body
    if function_name not in selected_block:
        return selected_block

    function_start_index = selected_block.index(function_name)
    open_brace_index = selected_block.find("{", function_start_index)
    if open_brace_index == -1:
        # the name is mentioned but no body follows it; nothing to extract
        return selected_block

    try:
        close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
    except ValueError:
        # unbalanced braces (truncated output): keep everything after the opening brace
        close_brace_index = len(selected_block)

    function_body = selected_block[open_brace_index + 1 : close_brace_index]
    return function_body + "}"
113+
114+
69115
class InferenceConfig(ABC):
70116

71117
def __init__(self, prompted : bool = False):
@@ -290,50 +336,70 @@ def format_prompt(self, prompt : str) -> str:
290336
return prompt.strip()
291337

292338
def clean_output(self, output: str, prompt: str) -> str:
293-
""" Clean LLM output to find code solution. The output should be in a ```c++ ``` code block. If there are
294-
multiple, then it tries to find the block with the function definition (as contained in the prompt).
295-
The code block itself may include the function definition and body OR just the body. This will try
296-
to parse both.
297-
"""
298-
# 0. replace up to the end of the first instance of prompt
299-
prompt_loc = output.find("@@ Response")
300-
if prompt_loc == -1:
301-
raise ValueError(f"Prompt not found in output: {prompt}")
302-
output = output[prompt_loc + len("@@ Response"):].strip()
303-
304-
# 1. Find all code blocks enclosed in triple backticks with "c++" language tag
305-
code_blocks = re.findall(r"```c\+\+\n(.*?)\n```", output, flags=re.DOTALL)
306-
code_blocks = [block.lstrip('```c++').rstrip('```') for block in code_blocks]
307-
308-
# 2. Prioritize code blocks containing the function definition from the prompt
309-
sub_prompt = prompt.rstrip().removesuffix("@@ Response").rstrip().removesuffix("```").split("```")[-1]
310-
function_name = get_function_name(sub_prompt, "cuda" if "__global__" in sub_prompt else "serial")
311-
prioritized_blocks = [block for block in code_blocks if function_name in block]
312-
313-
# 3. Choose the first block if multiple match, or any block if none match
314-
if len(code_blocks) > 0:
315-
selected_block = prioritized_blocks[0] if prioritized_blocks else code_blocks[0]
316-
else:
317-
if '```c++' in output: # starts with ```c++ but it didn't finish
318-
code_idx = output.find('```c++')
319-
selected_block = output[code_idx:].removeprefix('```c++')
320-
else:
321-
selected_block = output
339+
return clean_instruct_output(output, prompt, "@@ Response")
322340

323-
# 4. Handle cases where the block contains only the function body
324-
if function_name not in selected_block:
325-
return selected_block
326-
else:
327-
function_start_index = selected_block.index(function_name)
328-
open_brace_index = selected_block.find("{", function_start_index)
329-
try:
330-
close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
331-
except ValueError:
332-
close_brace_index = len(selected_block)
333341

334-
function_body = selected_block[open_brace_index + 1 : close_brace_index]
335-
return function_body + "}"
342+
class DeepSeekBaseConfig(InferenceConfig):
    """ Inference configuration used for the DeepSeek-Coder base checkpoints. """

    def __init__(self, prompted : bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        """ Torch dtype to load the model weights in. """
        return torch.bfloat16

    def init_padding(self, tokenizer):
        """ Configure the tokenizer for batched generation. """
        # reuse EOS as the pad token so prompts can be batched
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # decoder-only models are padded on the left
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        """ Pad token id as currently set on the tokenizer. """
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        """ End-of-sequence token id of the tokenizer. """
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        """ Whether model loading should trust remote code (it should not). """
        return False

    def format_prompt(self, prompt : str) -> str:
        """ Return the prompt text to feed the model.

        With ``prompted`` set, the prompt is passed through unstripped behind a
        short source-file style comment header; otherwise it is only stripped.
        """
        if not self.prompted:
            return prompt.strip()
        header = "// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n"
        return f"{header}{prompt}"

    def clean_output(self, output: str, prompt: str) -> str:
        """ Post-process generated text with the module-level ``clean_output``. """
        return clean_output(output, prompt)
370+
371+
372+
class InstructConfig(InferenceConfig):
    """ Inference configuration for instruction-tuned models.

    The prompt is wrapped between an instruction tag and a response tag; the
    response tag is later used to locate the model's answer in the output.
    """

    def __init__(self, prompted : bool = False, instruction_tag : str = "### Instruction", response_tag : str = "### Response"):
        super().__init__(prompted=prompted)
        self.instruction_tag = instruction_tag
        self.response_tag = response_tag

    def get_dtype(self):
        """ Torch dtype to load the model weights in. """
        return torch.bfloat16

    def init_padding(self, tokenizer):
        """ Configure the tokenizer for batched generation. """
        # reuse EOS as the pad token so prompts can be batched
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # decoder-only models are padded on the left
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        """ Pad token id as currently set on the tokenizer. """
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        """ End-of-sequence token id of the tokenizer. """
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        """ Whether model loading should trust remote code (it should not). """
        return False

    def format_prompt(self, prompt : str) -> str:
        """ Wrap the raw code prompt in an instruction/response template. """
        execution_model = "cuda" if "__global__" in prompt else "serial"
        function_name = get_function_name(prompt, execution_model)
        instruction = (
            "Complete the following c++ function.\n"
            f"```c++{prompt.strip()}```\n"
            f"Write only the function {function_name} and no other code. "
            "Enclose your solution in ```c++ and ```."
        )
        templated = f"{self.instruction_tag}\n{instruction}\n{self.response_tag}\n"
        return templated.strip()

    def clean_output(self, output: str, prompt: str) -> str:
        """ Extract the code solution that follows this config's response tag. """
        return clean_instruct_output(output, prompt, self.response_tag)
337403

338404
def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
339405
if model_name == "bigcode/starcoderbase":
@@ -350,6 +416,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
350416
return ReplitConfig(**kwargs)
351417
elif model_name.startswith('ise-uiuc/Magicoder'):
352418
return MagicoderConfig(**kwargs)
419+
elif model_name in ['deepseek-ai/deepseek-coder-6.7b-base', 'deepseek-ai/deepseek-coder-7b-base-v1.5']:
420+
return DeepSeekBaseConfig(**kwargs)
421+
elif model_name.startswith('hpcgroup/hpc-coder-v2'):
422+
return InstructConfig(instruction_tag='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:', response_tag='### Response:', **kwargs)
423+
elif model_name.startswith('hpcgroup/rlpf'):
424+
return InstructConfig(instruction_tag='### Instruction', response_tag='### Response', **kwargs)
353425
else:
354426
raise ValueError(f"Unknown model name: {model_name}")
355427

0 commit comments

Comments
 (0)