|
10 | 10 | parser.add_argument('--output', help='Path to the output JSON file') |
11 | 11 | parser.add_argument('--max_new_tokens', type=int, default=500, help='Maximum number of new tokens to generate (default: 500)') |
12 | 12 | parser.add_argument('--max_length', type=int, default=100, help='Maximum length (default: 100)') |
13 | | -parser.add_argument('--num-samples-per-prompt', type=int, default=100, help='Number of code samples to generate (default: 100)') |
| 13 | +parser.add_argument('--num_samples_per_prompt', type=int, default=100, help='Number of code samples to generate (default: 100)') |
14 | 14 | parser.add_argument('--temperature', type=float, default=0.9, help='Temperature for controlling randomness (default: 0.9)') |
15 | 15 | parser.add_argument('--top_p', type=float, default=0.9, help='Top p value for nucleus sampling (default: 0.9)') |
16 | 16 | parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)') |
|
21 | 21 |
|
22 | 22 | generator = pipeline(model=args.model, torch_dtype=torch.float16, device_map="auto") |
23 | 23 |
|
24 | | -responses = {} |
| 24 | +responses = [] |
25 | 25 |
|
26 | 26 | for item in prompts: |
| 27 | + answer = {} |
| 28 | + problem_type = item["problem_type"] |
| 29 | + language = item["language"] |
27 | 30 | name = item["name"] |
| 31 | + paral_model = item["parallelism_model"] |
28 | 32 | prompt = item["prompt"] |
29 | 33 |
|
30 | | - print(f"Generating code for: {name}") |
31 | | - |
32 | | - response = generator(prompt, max_length=args.max_length, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, temperature=args.temperature, top_p=args.top_p) |
| 34 | + answer["problem_type"] = problem_type |
| 35 | + answer["language"] = language |
| 36 | + answer["name"] = name |
| 37 | + answer["parallelism_model"] = paral_model |
| 38 | + answer["prompt"] = prompt |
| 39 | + output = [] |
33 | 40 |
|
34 | | - responses[name] = response[0]['generated_text'] |
| 41 | + for x in range(args.num_samples_per_prompt): |
| 42 | + print(f"Generating code for: {name}") |
| 43 | + |
| 44 | + response = generator(prompt, max_length=args.max_length, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, temperature=args.temperature, top_p=args.top_p) |
| 45 | + |
| 46 | + output.append(response[0]['generated_text'].replace(prompt, "").strip()) |
| 47 | + |
| 48 | + answer["output"] = output |
| 49 | + |
| 50 | + responses.append(answer) |
35 | 51 |
|
36 | 52 | if args.output: |
37 | 53 | with open(args.output, 'w') as output_file: |
|
0 commit comments