Skip to content

Commit 0a25b14

Browse files
authored
added sort and scan prompts (#5)
1 parent 91507ac commit 0a25b14

79 files changed

Lines changed: 1562 additions & 4 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

prompts/gather-raw-prompts.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
""" Use this script to gather the raw prompts from text files and export a
2+
single json in the correct format. The script also optionally checks for
3+
prompt validity and can do some preprocessing on the prompts.
4+
5+
author: Daniel Nichols
6+
date: October 2023
7+
"""
8+
# std imports
9+
from abc import ABC, abstractmethod
10+
from argparse import ArgumentParser
11+
import json
12+
import os
13+
import re
14+
from typing import List
15+
16+
17+
class Prompt(ABC):
    """ Base class for the per-model prompt handlers.

    A handler validates a raw prompt, can append a suffix to the name of the
    function declared on the prompt's last line, and can prepend the include
    lines its model needs.

    The class derives from ABC so that the @abstractmethod markers below are
    actually enforced: Prompt itself cannot be instantiated, and a subclass
    must override check_valid, get_model_name_for_function_suffix and
    add_imports before it can be.
    """

    def __init__(self, model: str):
        """ Store the model name this handler is responsible for. """
        self.model = model

    @abstractmethod
    def check_valid(self, prompt: str):
        """ Run the checks shared by every model.

        Subclasses override this and call super().check_valid(prompt) before
        adding their own checks.

        Raises:
            ValueError: if the prompt is empty, does not end with '{', or is
                missing one of "Example", "input:", "output:".
        """
        if prompt == "":
            raise ValueError("Prompt is empty")
        if not prompt.endswith("{"):
            raise ValueError(f"Prompt {prompt} does not end with {'{'}")
        for substr in ["Example", "input:", "output:"]:
            self.must_contain(prompt, substr)

    @abstractmethod
    def get_model_name_for_function_suffix(self):
        """ Return the model-specific suffix to append to function names. """
        raise NotImplementedError(f"Prompt.get_model_name_for_function_suffix() not implemented for {self.__class__.__name__}")

    def must_contain(self, prompt: str, substr: str):
        """ Raise ValueError unless substr occurs somewhere in prompt. """
        if substr not in prompt:
            raise ValueError(f"Prompt '{prompt}' does not contain '{substr}'")

    def append_to_function_name(self, prompt: str, suffix: str) -> str:
        """ prompt is a c++ docstring and function header. Add a suffix to the
            function name.

            Only the function name itself is modified. Other places in the
            header where the same text happens to appear (for example inside
            a parameter or type name) are left untouched.

            Raises:
                ValueError: if the prompt has fewer than two lines or no
                    function name can be found on its last line.
        """
        lines = prompt.split("\n")
        if len(lines) < 2:
            raise ValueError(f"Prompt '{prompt}' does not have at least two lines")

        # the function header is the last line of the prompt
        header = lines[-1]
        match = re.search(r"[^\s]+\s+(\w+)\s*\(", header)
        if not match:
            raise ValueError(f"Could not find function name in '{header}'")

        # Insert the suffix right after the matched name. A plain
        # str.replace() on the header would also rewrite every other
        # occurrence of the name, e.g. `sum` inside a parameter `sumInputs`.
        _, name_end = match.span(1)
        lines[-1] = header[:name_end] + suffix + header[name_end:]

        return "\n".join(lines)

    @abstractmethod
    def add_imports(self, prompt: str) -> str:
        """ Return prompt with the model's include lines prepended. """
        raise NotImplementedError(f"Prompt.add_imports() not implemented for {self.__class__.__name__}")
62+
63+
64+
class SerialPrompt(Prompt):
    """ Prompt handler registered under the model name "serial". """

    _MODEL = "serial"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Apply only the checks of the base class. """
        super().check_valid(prompt)

    def get_model_name_for_function_suffix(self):
        """ Return the empty string: nothing is appended for serial. """
        return ""

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt unchanged. """
        return prompt
76+
77+
class OpenMPPrompt(Prompt):
    """ Prompt handler registered under the model name "omp". """

    _MODEL = "omp"
    _REQUIRED = ("OpenMP",)
    _INCLUDES = "#include <omp.h>\n\n"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require that "OpenMP" is mentioned. """
        super().check_valid(prompt)
        for keyword in self._REQUIRED:
            self.must_contain(prompt, keyword)

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "OpenMP"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt with the omp.h include placed above it. """
        return self._INCLUDES + prompt
90+
91+
class MPIPrompt(Prompt):
    """ Prompt handler registered under the model name "mpi". """

    _MODEL = "mpi"
    _REQUIRED = ("MPI", "initialized")
    _INCLUDES = "#include <mpi.h>\n\n"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require "MPI" and "initialized". """
        super().check_valid(prompt)
        # checked in order, so the first missing keyword is the one reported
        for keyword in self._REQUIRED:
            self.must_contain(prompt, keyword)

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "MPI"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt with the mpi.h include placed above it. """
        return self._INCLUDES + prompt
106+
107+
class MPIOpenMPPrompt(Prompt):
    """ Prompt handler registered under the model name "mpi+omp". """

    _MODEL = "mpi+omp"
    _INCLUDES = "#include <mpi.h>\n#include <omp.h>\n\n"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require "MPI", "OpenMP" and
            "initialized", in that order.
        """
        super().check_valid(prompt)
        self.must_contain(prompt, "MPI")
        self.must_contain(prompt, "OpenMP")
        self.must_contain(prompt, "initialized")

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "MPIOpenMP"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt with the mpi.h and omp.h includes above it. """
        return self._INCLUDES + prompt
122+
123+
124+
class KokkosPrompt(Prompt):
    """ Prompt handler registered under the model name "kokkos". """

    _MODEL = "kokkos"
    _REQUIRED = ("Kokkos", "initialized")
    # TODO: figure out if including Kokkos_Core.hpp this way is ok
    _INCLUDES = "#include <Kokkos_Core.hpp>\n\n"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require "Kokkos" and "initialized". """
        super().check_valid(prompt)
        # checked in order, so the first missing keyword is the one reported
        for keyword in self._REQUIRED:
            self.must_contain(prompt, keyword)

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "Kokkos"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt with the Kokkos_Core.hpp include above it. """
        return self._INCLUDES + prompt
139+
140+
141+
class CUDAPrompt(Prompt):
    """ Prompt handler registered under the model name "cuda". """

    _MODEL = "cuda"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require "CUDA", "thread" and
            "__global__", in that order.
        """
        super().check_valid(prompt)
        self.must_contain(prompt, "CUDA")
        self.must_contain(prompt, "thread")
        self.must_contain(prompt, "__global__")

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "CUDA"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt unchanged. """
        return prompt
156+
157+
158+
class HIPPrompt(Prompt):
    """ Prompt handler registered under the model name "hip". """

    _MODEL = "hip"

    def __init__(self):
        super().__init__(self._MODEL)

    def check_valid(self, prompt: str):
        """ Run the base checks, then require "HIP", "thread" and
            "__global__", in that order.
        """
        super().check_valid(prompt)
        self.must_contain(prompt, "HIP")
        self.must_contain(prompt, "thread")
        self.must_contain(prompt, "__global__")

    def get_model_name_for_function_suffix(self):
        """ Return the suffix used for this model's function names. """
        return "HIP"

    def add_imports(self, prompt: str) -> str:
        """ Return the prompt unchanged. """
        return prompt
173+
174+
175+
def get_args():
    """ Parse the command line and return the resulting namespace.

    Accepted arguments:
        raw_prompts_root   positional, path to the root of the raw prompts
        -o / --output      optional path of the json file to write
        --function-suffix  one of "parallel", "model", "none" (default "none")
        --add-imports      flag, off unless given
    """
    # each entry is (flags, keyword arguments) for one add_argument() call
    argument_table = (
        (("raw_prompts_root",),
         {"type": str, "help": "path to root of raw prompts"}),
        (("-o", "--output"),
         {"help": "path to output json; defaults to stdout if not provided"}),
        (("--function-suffix",),
         {"choices": ["parallel", "model", "none"], "default": "none",
          "help": "suffix to add to function names"}),
        (("--add-imports",),
         {"action": "store_true", "help": "add imports above prompt"}),
    )

    cli = ArgumentParser(description=__doc__)
    for flags, settings in argument_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
182+
183+
184+
def parse_raw_prompt(
    type: str,
    name: str,
    fpath: os.PathLike,
    function_suffix: str = "none",
    add_imports: bool = False
) -> List[dict]:
    """ Parse the raw prompts from the text files. fpath points to a directory
        with a text prompt for each model. The text prompt should be named
        <model>.

        Args:
            type: problem type, stored as "problem_type" in every record.
            name: prompt name, stored as "name" in every record.
            fpath: directory holding exactly one file per model.
            function_suffix: "model" appends the model-specific suffix to the
                function name, "parallel" appends "Parallel", anything else
                leaves the name alone.
            add_imports: if True, prepend the model's include lines.

        Returns:
            One dict per model, always in the order serial, omp, mpi,
            mpi+omp, kokkos, cuda, hip.

        Raises:
            ValueError: if fpath does not hold exactly the expected entries,
                an entry is not a regular file, or a prompt fails validation.
    """
    # Insertion order of this dict is also the order of the returned records.
    # Iterating it (instead of the os.listdir() result, whose order is
    # arbitrary) keeps the exported json reproducible across machines.
    parsers = {"serial": SerialPrompt(), "omp": OpenMPPrompt(),
               "mpi": MPIPrompt(), "mpi+omp": MPIOpenMPPrompt(),
               "kokkos": KokkosPrompt(), "cuda": CUDAPrompt(), "hip": HIPPrompt()}

    found = set(os.listdir(fpath))
    expected = set(parsers)
    if found != expected:
        # report both directions: a stray extra file is as fatal as a missing one
        missing = sorted(expected - found)
        unexpected = sorted(found - expected)
        raise ValueError(
            f"{fpath} does not contain prompts for all models "
            f"(missing: {missing}, unexpected: {unexpected})"
        )

    prompts = []
    for model, parser in parsers.items():
        model_path = os.path.join(fpath, model)
        if not os.path.isfile(model_path):
            raise ValueError(f"Expected {model_path} to be a file")

        # read with an explicit encoding so the result does not depend on the
        # platform default
        with open(model_path, "r", encoding="utf-8") as fp:
            prompt = fp.read()

        # check the prompt is valid
        parser.check_valid(prompt)

        # add to the function name if necessary
        if function_suffix == "model":
            suffix = parser.get_model_name_for_function_suffix()
            prompt = parser.append_to_function_name(prompt, suffix)
        elif function_suffix == "parallel":
            prompt = parser.append_to_function_name(prompt, "Parallel")

        # add imports if necessary
        if add_imports:
            prompt = parser.add_imports(prompt)

        prompts.append({
            "problem_type": type,
            "language": "cpp",
            "name": name,
            "parallelism_model": model,
            "prompt": prompt
        })

    return prompts
237+
238+
def main():
    """ Walk the raw prompts root and export every prompt as one json list.

    The root is expected to be laid out as <prompt_type>/<prompt_name>/<model>.
    The list is written to the path given with --output, or printed to stdout
    when no output path is given.

    Raises:
        ValueError: if an entry of the root, or of a prompt type directory,
            is not itself a directory.
    """
    args = get_args()

    # prompts root is structured as <prompt_type>/<prompt_name>/<model>
    all_prompts = []
    # os.listdir() returns entries in arbitrary order; sort them so repeated
    # runs produce the same json
    for problem_type in sorted(os.listdir(args.raw_prompts_root)):
        type_path = os.path.join(args.raw_prompts_root, problem_type)
        if not os.path.isdir(type_path):
            raise ValueError(f"Expected {type_path} to be a directory")

        for prompt_name in sorted(os.listdir(type_path)):
            prompt_path = os.path.join(type_path, prompt_name)
            if not os.path.isdir(prompt_path):
                raise ValueError(f"Expected {prompt_path} to be a directory")

            prompts = parse_raw_prompt(problem_type, prompt_name, prompt_path,
                function_suffix=args.function_suffix, add_imports=args.add_imports)
            all_prompts.extend(prompts)

    # output
    if args.output:
        with open(args.output, "w") as fp:
            json.dump(all_prompts, fp, indent=4)
    else:
        print(json.dumps(all_prompts, indent=4))
264+
265+
266+
267+
# run the export only when this file is executed as a script
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)