Skip to content

Commit 0e5e5b5

Browse files
ZJShown (Zhaojun Xie) and Dando18 (Daniel Nichols)
authored
Generate updates (#8)
* updated generate.py --------- Co-authored-by: Zhaojun Xie <zxie12@login-1.zaratan.umd.edu> Co-authored-by: Daniel Nichols <dando18studios@gmail.com>
1 parent 0a25b14 commit 0e5e5b5

1 file changed

Lines changed: 22 additions & 6 deletions

File tree

generate.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
parser.add_argument('--output', help='Path to the output JSON file')
1111
parser.add_argument('--max_new_tokens', type=int, default=500, help='Maximum number of new tokens to generate (default: 500)')
1212
parser.add_argument('--max_length', type=int, default=100, help='Maximum length (default: 500)')
13-
parser.add_argument('--num-samples-per-prompt', type=int, default=100, help='Number of code samples to generate (default: 100)')
13+
parser.add_argument('--num_samples_per_prompt', type=int, default=100, help='Number of code samples to generate (default: 100)')
1414
parser.add_argument('--temperature', type=float, default=0.9, help='Temperature for controlling randomness (default: 0.9)')
1515
parser.add_argument('--top_p', type=float, default=0.9, help='Top p value for nucleus sampling (default: 0.9)')
1616
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
@@ -21,17 +21,33 @@
2121

2222
generator = pipeline(model=args.model, torch_dtype=torch.float16, device_map="auto")
2323

24-
responses = {}
24+
responses = []
2525

2626
for item in prompts:
27+
answer = {}
28+
problem_type = item["problem_type"]
29+
language = item["language"]
2730
name = item["name"]
31+
paral_model = item["parallelism_model"]
2832
prompt = item["prompt"]
2933

30-
print(f"Generating code for: {name}")
31-
32-
response = generator(prompt, max_length=args.max_length, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, temperature=args.temperature, top_p=args.top_p)
34+
answer["problem_type"] = problem_type
35+
answer["language"] = language
36+
answer["name"] = name
37+
answer["parallelism_model"] = paral_model
38+
answer["prompt"] = prompt
39+
output = []
3340

34-
responses[name] = response[0]['generated_text']
41+
for x in range(args.num_samples_per_prompt):
42+
print(f"Generating code for: {name}")
43+
44+
response = generator(prompt, max_length=args.max_length, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, temperature=args.temperature, top_p=args.top_p)
45+
46+
output.append(response[0]['generated_text'].replace(prompt, "").strip())
47+
48+
answer["output"] = output
49+
50+
responses.append(answer)
3551

3652
if args.output:
3753
with open(args.output, 'w') as output_file:

0 commit comments

Comments (0)