-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_baseline_gsm8k_comprehensive.pbs
More file actions
174 lines (141 loc) · 5.67 KB
/
run_baseline_gsm8k_comprehensive.pbs
File metadata and controls
174 lines (141 loc) · 5.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/bin/bash
#PBS -N UAG_Baseline_GSM8K_Comprehensive
#PBS -l walltime=12:00:00
#PBS -l select=1:ncpus=8:mem=64gb:ngpus=1
#PBS -j oe
#PBS -o baseline_gsm8k_comprehensive.log

# Fail on unset variables and broken pipelines; critical steps below are
# additionally checked explicitly so the job log shows *why* it stopped.
set -uo pipefail

echo "=== UAG Baseline GSM8K (Comprehensive Evaluation) ==="
echo "Start time: $(date)"
echo "Working directory: $(pwd)"

# Environment variables
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export http_proxy=http://127.0.0.1:7890
export https_proxy=http://127.0.0.1:7890
export MKL_INTERFACE_LAYER=LP64,ILP64
export HF_HOME="$HOME/.cache/huggingface"
export TRANSFORMERS_CACHE="$HOME/.cache/transformers"

# Activate conda environment; abort early if it is missing.
source "$HOME/anaconda3/bin/activate" uag \
  || { echo "ERROR: failed to activate conda env 'uag'" >&2; exit 1; }

# Check GPU (informational only — do not kill the job if nvidia-smi fails)
echo "=== Check GPU ==="
nvidia-smi || true

# Change to the directory the job was submitted from.
# BUG FIX: the original used 'cd $PWD', which is a no-op — PBS starts jobs
# in $HOME, so the script never ran from the submission directory. PBS
# exposes the submission directory as $PBS_O_WORKDIR; fall back to $PWD for
# interactive (non-PBS) runs.
cd "${PBS_O_WORKDIR:-$PWD}" \
  || { echo "ERROR: cannot cd to working directory" >&2; exit 1; }

echo "=== Start GSM8K Baseline (Comprehensive) ==="
echo "Dataset: GSM8K (1319 samples)"
echo "Model: Mistral-7B-Instruct-v0.3"
echo "Method: Standard Chain-of-Thought (without UAG)"
echo "Evaluation: Comprehensive metrics (Exact Match, F1, BLEU, ROUGE-L, etc.)"
# Run baseline
# Feed the evaluation program to python via a QUOTED heredoc ('PYEOF'), so
# nothing inside is subject to shell expansion or backslash escaping. The
# original 'python -c "…"' form required \\n / \" / \' escapes throughout,
# which is fragile and hard to audit.
python - <<'PYEOF'
"""GSM8K baseline: standard chain-of-thought (no UAG) + comprehensive metrics."""
import json
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from code.comprehensive_metrics import ComprehensiveMetrics


def extract_prediction(text):
    """Return the last number in `text` with thousands separators removed.

    Returns '' when no number is present. Stripping commas is required so
    model outputs like '1,200' can match the comma-stripped gold answer.
    """
    numbers = re.findall(r'-?\d[\d,]*(?:\.\d+)?', text)
    return numbers[-1].replace(',', '') if numbers else ''


def extract_gold(answer):
    """Return the final GSM8K answer (the token after '####'), normalized.

    BUG FIX vs. the original r'#### (\d+)': GSM8K golds may be negative or
    contain thousands separators (e.g. '#### 1,200'); those samples were
    previously scored as wrong no matter what the model produced.
    """
    m = re.search(r'####\s*(-?[\d,\.]+)', answer)
    if m:
        return m.group(1).replace(',', '').rstrip('.')
    return answer.strip()


print('=== Load model ===')
model_name = 'mistralai/Mistral-7B-Instruct-v0.3'
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    local_files_only=True,
)

print('=== Load metrics ===')
metrics = ComprehensiveMetrics()

print('=== Load data ===')
with open('GSM8K_input.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]
print(f'Loaded {len(data)} samples')

print('=== Start baseline ===')
results = []
comprehensive_results = []
for i, item in enumerate(data):
    if i % 100 == 0:
        print(f'Progress: {i}/{len(data)}')
    question = item['question']
    answer = item['answer']

    # Standard zero-shot CoT prompt.
    prompt = f"Q: {question}\nA: Let's think step by step."
    inputs = tokenizer(prompt, return_tensors='pt')
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            # Pass the full tokenizer output so attention_mask reaches
            # generate() (the original passed only input_ids, which is
            # unreliable when pad_token_id == eos_token_id).
            **inputs,
            # BUG FIX: max_length counts the prompt tokens too, so long
            # questions silently shrank the generation budget; max_new_tokens
            # gives every sample the same budget.
            max_new_tokens=1024,
            temperature=0.7,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )

    pred = extract_prediction(response)
    true_answer = extract_gold(answer)
    is_correct = 1 if pred == true_answer else 0

    comprehensive_result = metrics.evaluate_sample(
        prediction=pred,
        ground_truth=answer,
        reasoning=response,
    )
    results.append({
        'question': question,
        'answer': answer,
        'prediction': pred,
        'response': response,
        'correct': is_correct,
        'comprehensive_metrics': comprehensive_result,
    })
    comprehensive_results.append(comprehensive_result)

# Overall metrics (guard against an empty input file).
total_correct = sum(r['correct'] for r in results)
accuracy = (total_correct / len(data) * 100) if data else 0.0
aggregate_metrics = metrics.evaluate_batch(comprehensive_results)

print('\n=== Baseline results ===')
print(f'Total samples: {len(data)}')
print(f'Num correct: {total_correct}')
print(f'Accuracy: {accuracy:.2f}%')
print('\n=== Comprehensive results ===')
print(f'Exact Match Accuracy: {aggregate_metrics.get("exact_match_accuracy", 0):.4f}')
print(f'F1 Score: {aggregate_metrics.get("f1_score_mean", 0):.4f} ± {aggregate_metrics.get("f1_score_std", 0):.4f}')
print(f'BLEU Score: {aggregate_metrics.get("bleu_score_mean", 0):.4f} ± {aggregate_metrics.get("bleu_score_std", 0):.4f}')
print(f'ROUGE-L: {aggregate_metrics.get("rouge_l_mean", 0):.4f} ± {aggregate_metrics.get("rouge_l_std", 0):.4f}')

# Reasoning-quality metrics are optional in ComprehensiveMetrics output.
if 'reasoning_step_indicator_ratio_mean' in aggregate_metrics:
    print('\n=== Reasoning quality metrics ===')
    print(f'Step Indicator Ratio: {aggregate_metrics.get("reasoning_step_indicator_ratio_mean", 0):.4f}')
    print(f'Math Indicator Ratio: {aggregate_metrics.get("reasoning_math_indicator_ratio_mean", 0):.4f}')
    print(f'Logic Indicator Ratio: {aggregate_metrics.get("reasoning_logic_indicator_ratio_mean", 0):.4f}')
    print(f'Avg Sentence Length: {aggregate_metrics.get("reasoning_avg_sentence_length_mean", 0):.2f}')

# Save per-sample results and the aggregate summary.
os.makedirs('results', exist_ok=True)
with open('results/GSM8K_baseline_results_comprehensive.jsonl', 'w', encoding='utf-8') as f:
    for result in results:
        f.write(json.dumps(result, ensure_ascii=False) + '\n')
with open('results/GSM8K_baseline_metrics_summary.json', 'w', encoding='utf-8') as f:
    json.dump({
        'task': 'GSM8K',
        'method': 'baseline',
        'total_samples': len(data),
        'exact_match_accuracy': accuracy / 100,
        'comprehensive_metrics': aggregate_metrics,
    }, f, indent=2, ensure_ascii=False)

print('\nSaved to:')
print('- results/GSM8K_baseline_results_comprehensive.jsonl')
print('- results/GSM8K_baseline_metrics_summary.json')
PYEOF

echo "=== GSM8K baseline completed ==="
echo "End time: $(date)"