# std imports
from abc import ABC, abstractmethod
from typing import Optional

# tpl imports
import torch
from torch.utils.data import Dataset
from transformers import StoppingCriteria
8+
9+
class InferenceConfig(ABC):
    """Abstract interface describing how to run inference for a model family.

    Concrete subclasses specify the load dtype, how to configure tokenizer
    padding, which special-token ids to use during generation, and how to
    format a raw prompt for the model.
    """

    def __init__(self, prompted: bool = False):
        # If True, format_prompt prepends a model-specific steering preamble.
        self.prompted = prompted

    @abstractmethod
    def get_dtype(self):
        """Return the torch dtype the model weights should be loaded in."""
        pass

    @abstractmethod
    def init_padding(self, tokenizer):
        """Configure the tokenizer's padding (pad token id, padding side) in place."""
        pass

    @abstractmethod
    def get_pad_token_id(self, tokenizer) -> int:
        """Return the token id to use for padding during generation."""
        pass

    @abstractmethod
    def get_eos_token_id(self, tokenizer) -> Optional[int]:
        """Return the end-of-sequence token id, or None to keep the model default.

        Annotated Optional because some implementations (e.g. StarCoder)
        deliberately return None.
        """
        pass

    @abstractmethod
    def format_prompt(self, prompt: str) -> str:
        """Return the prompt text, possibly wrapped in a model-specific preamble."""
        pass
34+
35+
class StarCoderConfig(InferenceConfig):
    """Inference settings for bigcode/starcoderbase."""

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> Optional[int]:
        # None: let generation fall back to the model's default EOS handling.
        # (Annotation widened from `int` — this method always returns None.)
        return None

    def format_prompt(self, prompt: str) -> str:
        if self.prompted:
            # StarCoder-style steering: a <filename> tag plus a correctness comment.
            return f"<filename>solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()
58+
class CodeLlamaConfig(InferenceConfig):
    """Inference settings for base (non-Instruct) codellama/CodeLlama-* models."""

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.float16

    def init_padding(self, tokenizer):
        # (dropped a dead trailing `pass` statement)
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        # pad_token_id is aliased to eos_token_id by init_padding
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt: str) -> str:
        if self.prompted:
            # filename + correctness hint steers the model toward a complete solution
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()
82+
83+
class PolyCoderConfig(InferenceConfig):
    """Inference settings for NinedayWang/PolyCoder-2.7B."""

    # Steering preamble prepended when `prompted` is enabled.
    _PROMPT_HEADER = (
        "// filename: solutions/solution_1.cpp\n"
        "// here is the correct implementation of the coding exercise\n\n"
    )

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        # run the model in half precision
        return torch.float16

    def init_padding(self, tokenizer):
        # decoder-only model: reuse EOS as the pad token and left-pad batches
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt: str) -> str:
        if not self.prompted:
            return prompt.strip()
        return self._PROMPT_HEADER + prompt
106+
107+
class PhindConfig(InferenceConfig):
    """Inference settings for Phind/Phind-CodeLlama-34B-v2."""

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        # load weights in half precision
        return torch.float16

    def init_padding(self, tokenizer):
        # decoder-only model: pad with EOS and pad on the left so generation
        # continues from the real end of each prompt
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def format_prompt(self, prompt: str) -> str:
        header = (
            "// filename: solutions/solution_1.cpp\n"
            "// here is the correct implementation of the coding exercise\n\n"
        )
        return header + prompt if self.prompted else prompt.strip()
130+
131+
def get_inference_config(model_name: str, **kwargs) -> InferenceConfig:
    """Look up the InferenceConfig for a given HuggingFace model name.

    Arguments:
        model_name: HuggingFace hub id of the model.
        kwargs: forwarded to the config's constructor (e.g. prompted=True).

    Raises:
        ValueError: if the model name is not recognized.
    """
    exact_matches = {
        "bigcode/starcoderbase": StarCoderConfig,
        "NinedayWang/PolyCoder-2.7B": PolyCoderConfig,
        "Phind/Phind-CodeLlama-34B-v2": PhindConfig,
    }
    if model_name in exact_matches:
        return exact_matches[model_name](**kwargs)
    # any base (non-Instruct) CodeLlama variant shares one config
    if model_name.startswith("codellama/CodeLlama-") and 'Instruct' not in model_name:
        return CodeLlamaConfig(**kwargs)
    raise ValueError(f"Unknown model name: {model_name}")
143+
144+
def clean_output(output: str, prompt: str) -> str:
    """Remove `prompt` from the beginning of `output` and truncate the result at
    the end of the function definition (i.e. the matching closing brace).

    Arguments:
        output: generated text that contains `prompt`.
        prompt: the prompt text to strip; everything up to and including its
            first occurrence is removed.

    Returns:
        The function body following the prompt, truncated just past the brace
        that closes the body opened by the prompt. If the braces never balance,
        the whole (stripped) remainder is returned.

    Raises:
        ValueError: if `prompt` does not occur in `output`.
    """
    # drop everything up to the end of the first occurrence of the prompt
    prompt_loc = output.find(prompt)
    if prompt_loc == -1:
        raise ValueError(f"Prompt not found in output: {prompt}")
    output = output[prompt_loc + len(prompt):].strip()

    # Temporarily prepend the function's opening brace (the prompt ends inside
    # the function body) so we can scan for its matching closing brace.
    output = '{' + output

    # depth counter instead of a stack of identical '{' tokens
    depth = 0
    index = 0
    while index < len(output):
        ch = output[index]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # found the brace matching the prepended '{'
                break
        index += 1

    # drop the temporary brace and truncate just past the matching brace
    return output[1:index + 1]
175+
class PromptDataset(Dataset):
    """Minimal PyTorch ``Dataset`` wrapping a list of prompt strings.

    The strings may have different lengths; no padding or tokenization is
    performed here.
    """

    def __init__(self, prompts):
        super().__init__()
        # stored as-is; items are returned untouched by __getitem__
        self.prompts_ = prompts

    def __len__(self):
        return len(self.prompts_)

    def __getitem__(self, idx):
        return self.prompts_[idx]
189+
190+
def has_balanced_brackets(text: str, left_bracket: str = '{', right_bracket: str = '}') -> bool:
    '''Check if string has balanced brackets.
    modified from: https://stackoverflow.com/a/38834249/3769237

    Arguments:
        text: string to check for balanced brackets in.
        left_bracket: left bracket to balance
        right_bracket: right bracket to balance

    Returns:
        true if left_bracket and right_bracket are balanced
    '''
    # depth counter replaces a stack of identical tokens
    depth = 0
    for ch in text:
        if ch == left_bracket:
            depth += 1
        elif ch == right_bracket:
            depth -= 1
            if depth < 0:
                # closing bracket with no matching opener
                return False
    return depth == 0
219+
220+
class BalancedBracketsCriteria(StoppingCriteria):
    ''' extension of transformers' text-generation stopping criteria.
    Stops either when function is complete (i.e. { and } are balanced) or when max_length is surpassed, whichever
    happens first.

    _Note:_ This is a slow stopping criteria, but it's much faster than continually running model inference when
    it does not need to be run anymore.
    '''

    def __init__(self, max_length: int, tokenizer, left_bracket: str = '{', right_bracket: str = '}'):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.left_bracket = left_bracket
        self.right_bracket = right_bracket

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # length cutoff: stop regardless of bracket state
        if input_ids.shape[-1] > self.max_length:
            return True

        # stop only once every sequence in the batch decodes to balanced
        # brackets, i.e. every generated function is complete
        for sequence in input_ids:
            decoded = self.tokenizer.decode(sequence)
            if not has_balanced_brackets(
                decoded,
                left_bracket=self.left_bracket,
                right_bracket=self.right_bracket,
            ):
                return False
        return True