-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_uag_gsm8k_fixed.pbs
More file actions
79 lines (67 loc) · 1.96 KB
/
run_uag_gsm8k_fixed.pbs
File metadata and controls
79 lines (67 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/bin/bash
#PBS -N UAG_GSM8K_Fixed
#PBS -l select=1:ncpus=4:ngpus=1:mem=32GB
#PBS -l walltime=12:00:00
#PBS -q v1_gpu72
#PBS -j oe
#
# PBS batch job: UAG experiment on GSM8K (1 node, 4 CPUs, 1 GPU, 32 GB, 12 h).
# Submit with qsub from the project root, i.e. the directory that contains
# the `code/` package and GSM8K_input.jsonl. stdout/stderr are merged (-j oe).

echo "=== UAG GSM8K (Fixed) ==="
echo "Start time: $(date)"

# PBS starts batch jobs in the user's home directory, so `cd $PWD` would be a
# no-op and every relative path would resolve against $HOME. PBS_O_WORKDIR is
# the directory qsub was invoked from; fall back to the current directory when
# the script is run by hand outside PBS.
workdir="${PBS_O_WORKDIR:-$PWD}"
cd "$workdir" || { echo "ERROR: cannot cd to $workdir" >&2; exit 1; }
echo "Working directory: $PWD"

# Environment variables
export HF_HOME="$HOME/.cache/huggingface"
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME (whose model cache lives under "$HF_HOME/hub"). Kept as-is
# so models already cached at this exact location are still found.
export TRANSFORMERS_CACHE="$HOME/.cache/huggingface"
# Let the PyTorch CUDA caching allocator use expandable segments, which can
# reduce out-of-memory errors caused by fragmentation.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# NOTE(review): presumably set to avoid an MKL interface-layer mismatch
# between numpy/torch builds — confirm it is still required.
export MKL_INTERFACE_LAYER=LP64

# Check GPU (diagnostic only: a failure is reported but does not abort).
echo "=== Check GPU ==="
nvidia-smi || echo "WARNING: nvidia-smi failed; is a GPU visible to this job?" >&2

# Activate conda environment. Abort on failure: otherwise the job would
# silently continue with whatever `python` happens to be first on PATH.
echo "=== Activate conda environment ==="
conda_sh="$HOME/anaconda3/etc/profile.d/conda.sh"
# shellcheck source=/dev/null
source "$conda_sh" || { echo "ERROR: cannot source $conda_sh" >&2; exit 1; }
conda activate uag || { echo "ERROR: cannot activate conda env 'uag'" >&2; exit 1; }

echo "=== Start UAG GSM8K (Fixed) ==="
echo "Dataset: GSM8K (1319 samples)"
echo "Model: Mistral-7B-Instruct-v0.3"
echo "Method: True UAG (Uncertainty-aware Adaptive Guidance)"
echo "Includes: Uncertainty computation, demo clustering, adaptive reasoning"
echo "Fixes: All import and data format issues"
# Run fixed UAG experiment.
# The Python driver is fed on stdin through a quoted here-document, so the
# shell performs no expansion inside it ($, backticks and quotes stay literal).
# With `python -`, sys.path[0] is the current working directory.

# record_path below writes into results/; create it in case UAGFixed does not.
mkdir -p results || { echo "ERROR: cannot create results/ directory" >&2; exit 1; }

python - <<'PY'
import argparse
import os
import sys

# Put the working directory at the front of the module search path so the
# project's local "code" package wins over the standard-library "code"
# module of the same name.
sys.path.insert(0, os.getcwd())

from code.uag import UAGFixed

# Set all required UAG parameters
args = argparse.Namespace()
args.task = 'GSM8K'
args.data_path = 'GSM8K_input.jsonl'
args.record_path = 'results/GSM8K_uag_results_fixed.jsonl'
# NOTE(review): demonstrations are read from the same file that is evaluated;
# confirm this is intended (possible test-set leakage into the prompts).
args.demonstration_path = 'GSM8K_input.jsonl'
args.demonstration_number = 16
args.theta = 16
args.lambda1 = 0.5
args.lambda2 = 0.5
args.k = 8
args.temperature = 0.5
args.max_length = 8192
args.max_new_tokens = 2048
args.max_loop = 10
args.device = 'cuda'

print(f'Starting fixed UAG experiment...')
print(f'Task: {args.task}')
print(f'Data path: {args.data_path}')
print(f'Theta: {args.theta}')
print(f'Lambda1: {args.lambda1}, Lambda2: {args.lambda2}')
print(f'K: {args.k}')

uag = UAGFixed(args)
accuracy = uag.process()

print(f'\n=== Completed! ===')
print(f'GSM8K UAG accuracy: {accuracy:.4f}')
print(f'Results saved to: {args.record_path}')
PY
status=$?

if (( status == 0 )); then
  echo "=== UAG GSM8K completed ==="
else
  echo "=== UAG GSM8K FAILED (python exit status $status) ===" >&2
fi
echo "End time: $(date)"

# Propagate the experiment's status so PBS records a crashed run as failed;
# otherwise the job's status would be that of the last echo (always 0).
exit "$status"