1+ """ Use this script to gather the raw prompts from text files and export a
2+ single json in the correct format. The script also optionally checks for
3+ prompt validity and can do some preprocessing on the prompts.
4+
5+ author: Daniel Nichols
6+ date: October 2023
7+ """
8+ # std imports
9+ from abc import ABC , abstractmethod
10+ from argparse import ArgumentParser
11+ import json
12+ import os
13+ import re
14+ from typing import List
15+
16+
class Prompt(ABC):
    """Abstract base for per-model prompt validation and preprocessing.

    Each concrete subclass represents one parallelism model (identified by
    ``self.model``) and supplies the model-specific validity checks, the
    function-name suffix, and the include lines to prepend.
    """

    def __init__(self, model: str):
        # short identifier of the parallelism model, e.g. "omp" or "mpi"
        self.model = model

    @abstractmethod
    def check_valid(self, prompt: str):
        """Run the checks common to every model.

        Subclasses override this, call ``super().check_valid(prompt)`` and
        then add their own requirements.

        Raises:
            ValueError: if the prompt is empty, does not end with ``{``, or
                lacks one of "Example", "input:" or "output:".
        """
        if prompt == "":
            raise ValueError("Prompt is empty")
        if not prompt.endswith("{"):
            raise ValueError(f"Prompt {prompt} does not end with {'{'} ")
        for substr in ["Example", "input:", "output:"]:
            self.must_contain(prompt, substr)

    @abstractmethod
    def get_model_name_for_function_suffix(self):
        """Return the model name to append to the prompt's function name."""
        raise NotImplementedError(f"Prompt.get_model_name_for_function_suffix() not implemented for {self.__class__.__name__} ")

    def must_contain(self, prompt: str, substr: str):
        """Raise ValueError unless ``substr`` occurs somewhere in ``prompt``."""
        if substr not in prompt:
            raise ValueError(f"Prompt '{prompt} ' does not contain '{substr} '")

    def append_to_function_name(self, prompt: str, suffix: str) -> str:
        """ prompt is a c++ docstring and function header. Add a suffix to the
            function name.

            The function header is taken to be the last line of the prompt.
            Only the function name itself is modified; other identifiers in
            the header that merely contain the function name (for example a
            parameter ``sumIn`` next to a function ``sum``) are left alone.

            Raises:
                ValueError: if the prompt has fewer than two lines or no
                    function name can be found in the last line.
        """
        lines = prompt.split("\n")
        if len(lines) < 2:
            raise ValueError(f"Prompt '{prompt} ' does not have at least two lines")

        # get the function name: an identifier preceded by a return type (or
        # qualifier) and followed by the opening parenthesis
        header = lines[-1]
        match = re.search(r"[^\s]+\s+(\w+)\s*\(", header)
        if not match:
            raise ValueError(f"Could not find function name in '{header} '")

        # insert the suffix directly after the matched name instead of using
        # str.replace, which would rewrite every occurrence of the name
        name_end = match.end(1)
        lines[-1] = header[:name_end] + suffix + header[name_end:]

        return "\n".join(lines)

    @abstractmethod
    def add_imports(self, prompt: str) -> str:
        """Return ``prompt`` with the model's include lines prepended."""
        raise NotImplementedError(f"Prompt.add_imports() not implemented for {self.__class__.__name__} ")
62+
63+
class SerialPrompt(Prompt):
    """Prompt handler for the "serial" model."""

    def __init__(self):
        model_id = "serial"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Apply the base-class checks only; nothing extra is required."""
        super().check_valid(prompt)

    def get_model_name_for_function_suffix(self):
        """Return the empty string, so serial function names are unchanged."""
        no_suffix = ""
        return no_suffix

    def add_imports(self, prompt: str) -> str:
        """Return the prompt as-is; no include lines are prepended."""
        return prompt
76+
class OpenMPPrompt(Prompt):
    """Prompt handler for the "omp" model."""

    # text placed in front of the prompt by add_imports
    _INCLUDE_HEADER = "#include <omp.h>\n \n "

    def __init__(self):
        model_id = "omp"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require the prompt to mention "OpenMP"."""
        super().check_valid(prompt)
        required = "OpenMP"
        self.must_contain(prompt, required)

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "OpenMP"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt with the OpenMP include line prepended."""
        return self._INCLUDE_HEADER + prompt
90+
class MPIPrompt(Prompt):
    """Prompt handler for the "mpi" model."""

    # text placed in front of the prompt by add_imports
    _INCLUDE_HEADER = "#include <mpi.h>\n \n "

    def __init__(self):
        model_id = "mpi"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require "MPI" and "initialized"."""
        super().check_valid(prompt)
        # checked in this order so the first missing term is the one reported
        for required in ("MPI", "initialized"):
            self.must_contain(prompt, required)

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "MPI"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt with the MPI include line prepended."""
        return self._INCLUDE_HEADER + prompt
106+
class MPIOpenMPPrompt(Prompt):
    """Prompt handler for the hybrid "mpi+omp" model."""

    # text placed in front of the prompt by add_imports
    _INCLUDE_HEADER = "#include <mpi.h>\n #include <omp.h>\n \n "

    def __init__(self):
        model_id = "mpi+omp"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require "MPI", "OpenMP", "initialized"."""
        super().check_valid(prompt)
        # one call per required term, in the order they are reported
        self.must_contain(prompt, "MPI")
        self.must_contain(prompt, "OpenMP")
        self.must_contain(prompt, "initialized")

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "MPIOpenMP"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt with the MPI and OpenMP include lines prepended."""
        return self._INCLUDE_HEADER + prompt
122+
123+
class KokkosPrompt(Prompt):
    """Prompt handler for the "kokkos" model."""

    # text placed in front of the prompt by add_imports
    # todo -- figure out if this is ok
    _INCLUDE_HEADER = "#include <Kokkos_Core.hpp>\n \n "

    def __init__(self):
        model_id = "kokkos"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require "Kokkos" and "initialized"."""
        super().check_valid(prompt)
        # checked in this order so the first missing term is the one reported
        for required in ("Kokkos", "initialized"):
            self.must_contain(prompt, required)

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "Kokkos"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt with the Kokkos include line prepended."""
        return self._INCLUDE_HEADER + prompt
139+
140+
class CUDAPrompt(Prompt):
    """Prompt handler for the "cuda" model."""

    def __init__(self):
        model_id = "cuda"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require "CUDA", "thread", "__global__"."""
        super().check_valid(prompt)
        # one call per required term, in the order they are reported
        self.must_contain(prompt, "CUDA")
        self.must_contain(prompt, "thread")
        self.must_contain(prompt, "__global__")

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "CUDA"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt as-is; no include lines are prepended."""
        return prompt
156+
157+
class HIPPrompt(Prompt):
    """Prompt handler for the "hip" model."""

    def __init__(self):
        model_id = "hip"
        super().__init__(model_id)

    def check_valid(self, prompt: str):
        """Run the base checks, then require "HIP", "thread", "__global__"."""
        super().check_valid(prompt)
        # one call per required term, in the order they are reported
        self.must_contain(prompt, "HIP")
        self.must_contain(prompt, "thread")
        self.must_contain(prompt, "__global__")

    def get_model_name_for_function_suffix(self):
        """Return the suffix used for function names under this model."""
        return "HIP"

    def add_imports(self, prompt: str) -> str:
        """Return the prompt as-is; no include lines are prepended."""
        return prompt
173+
174+
def get_args():
    """Define the command-line interface and return the parsed arguments.

    The module docstring is used as the program description. Arguments are
    read from the process command line.
    """
    cli = ArgumentParser(description=__doc__)

    # positional: where the raw prompt tree lives
    cli.add_argument(
        "raw_prompts_root",
        type=str,
        help="path to root of raw prompts",
    )

    # optional: destination file for the exported json
    cli.add_argument(
        "-o",
        "--output",
        help="path to output json; defaults to stdout if not provided",
    )

    # optional preprocessing switches
    suffix_modes = ["parallel", "model", "none"]
    cli.add_argument(
        "--function-suffix",
        choices=suffix_modes,
        default="none",
        help="suffix to add to function names",
    )
    cli.add_argument(
        "--add-imports",
        action="store_true",
        help="add imports above prompt",
    )

    parsed = cli.parse_args()
    return parsed
182+
183+
def parse_raw_prompt(
    type: str,
    name: str,
    fpath: os.PathLike,
    function_suffix: str = "none",
    add_imports: bool = False
) -> List[dict]:
    """ Parse the raw prompts from the text files. fpath points to a directory
        with a text prompt for each model. The text prompt should be named
        <model>.

        Args:
            type: problem type, stored as "problem_type" in each record.
            name: prompt name, stored as "name" in each record.
            fpath: directory holding exactly one file per model.
            function_suffix: "model" appends the model's name to the function
                name, "parallel" appends "Parallel", anything else leaves the
                function name untouched.
            add_imports: when true, prepend the model's include lines.

        Returns:
            One dict per model, ordered by model name so that repeated runs
            produce identical output.

        Raises:
            ValueError: if the directory entries are not exactly the known
                models, an entry is not a regular file, or a prompt fails
                its validity check.
    """
    parsers = {"serial": SerialPrompt(), "omp": OpenMPPrompt(),
               "mpi": MPIPrompt(), "mpi+omp": MPIOpenMPPrompt(),
               "kokkos": KokkosPrompt(), "cuda": CUDAPrompt(), "hip": HIPPrompt()}

    # the directory must hold exactly the known models; the parser table is
    # the single source of truth for what those are
    expected = set(parsers)
    found = set(os.listdir(fpath))
    if found != expected:
        missing = sorted(expected - found)
        unexpected = sorted(found - expected)
        raise ValueError(
            f"{fpath} does not contain prompts for all models"
            f" (missing: {missing}, unexpected: {unexpected})"
        )

    prompts = []
    # os.listdir gives entries in arbitrary order; sort for stable output
    for model in sorted(found):
        model_path = os.path.join(fpath, model)
        if not os.path.isfile(model_path):
            raise ValueError(f"Expected {model_path} to be a file")

        # get the parser and read in the contents of the prompt
        parser = parsers[model]
        # explicit encoding so the result does not depend on the locale
        with open(model_path, "r", encoding="utf-8") as fp:
            prompt = fp.read()

        # check the prompt is valid
        parser.check_valid(prompt)

        # add to the function name if necessary
        if function_suffix == "model":
            suffix = parser.get_model_name_for_function_suffix()
            prompt = parser.append_to_function_name(prompt, suffix)
        elif function_suffix == "parallel":
            prompt = parser.append_to_function_name(prompt, "Parallel")

        # add imports if necessary
        if add_imports:
            prompt = parser.add_imports(prompt)

        prompts.append({
            "problem_type": type,
            "language": "cpp",
            "name": name,
            "parallelism_model": model,
            "prompt": prompt
        })

    return prompts
237+
def main():
    """Collect every raw prompt under the root directory and export as json.

    The root is walked as <prompt_type>/<prompt_name>/<model>. Directories
    are visited in sorted order so the exported json is the same on every
    run. The result is written to the --output path when given, otherwise
    printed to stdout.

    Raises:
        ValueError: if an entry at the type or prompt-name level is not a
            directory.
    """
    args = get_args()

    # prompts root is structured as <prompt_type>/<prompt_name>/<model>
    all_prompts = []
    # os.listdir order is arbitrary, so sort at each level
    for problem_type in sorted(os.listdir(args.raw_prompts_root)):
        type_path = os.path.join(args.raw_prompts_root, problem_type)
        if not os.path.isdir(type_path):
            raise ValueError(f"Expected {type_path} to be a directory")

        for prompt_name in sorted(os.listdir(type_path)):
            prompt_path = os.path.join(type_path, prompt_name)
            if not os.path.isdir(prompt_path):
                raise ValueError(f"Expected {prompt_path} to be a directory")

            prompts = parse_raw_prompt(problem_type, prompt_name, prompt_path,
                function_suffix=args.function_suffix, add_imports=args.add_imports)
            all_prompts.extend(prompts)

    # output
    if args.output:
        with open(args.output, "w") as fp:
            json.dump(all_prompts, fp, indent=4)
    else:
        print(json.dumps(all_prompts, indent=4))
264+
265+
266+
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments