-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathbase_config.py
More file actions
128 lines (100 loc) · 6.01 KB
/
base_config.py
File metadata and controls
128 lines (100 loc) · 6.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../scripts"))
from libinfinicore_infer import DeviceType
class BaseConfig:
    """InfiniLM Unified Config - Command line argument parser.

    Parses the common InfiniLM command-line flags into typed attributes.
    Unrecognized arguments are kept in ``self.extra`` so downstream
    scripts/subclasses can consume them (``parse_known_args`` is used
    instead of ``parse_args`` for exactly this reason).
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(description="InfiniLM Unified Config")
        self._add_common_args()
        # parse_known_args: tolerate flags owned by other components.
        self.args, self.extra = self.parser.parse_known_args()
        # --- base configuration ---
        self.model = self.args.model
        self.device_name = self.args.device
        self.device_type = self._get_device_type(self.args.device)
        self.tp = self.args.tp
        # --- inference backend optimization ---
        self.attn = self.args.attn
        self.enable_graph = self.args.enable_graph
        self.cache_type = self.args.cache_type
        self.enable_paged_attn = self.args.enable_paged_attn
        self.paged_kv_block_size = self.args.paged_kv_block_size
        self.kv_cache_dtype = self.args.kv_cache_dtype
        self.skip_load = self.args.skip_load
        # --- length and sampling parameters ---
        self.batch_size = self.args.batch_size
        self.input_len = self.args.input_len
        self.output_len = self.args.output_len
        self.max_new_tokens = self.args.max_new_tokens
        self.top_k = self.args.top_k
        self.top_p = self.args.top_p
        self.temperature = self.args.temperature
        # BUG FIX: "--warmup" is stored by argparse under dest "warmup";
        # the previous `self.args.warm_up` raised AttributeError.
        self.warm_up = self.args.warmup
        self.verbose = self.args.verbose
        # "log_evel" is a historical typo; keep the old attribute for
        # backward compatibility and expose the corrected alias too.
        self.log_evel = self.args.log_evel
        self.log_level = self.args.log_evel
        # --- evaluation parameters ---
        self.bench = self.args.bench
        self.backend = self.args.backend
        self.ndev = self.args.ndev
        self.subject = self.args.subject
        self.split = self.args.split
        self.num_samples = self.args.num_samples
        self.output_csv = self.args.output_csv
        self.cache_dir = self.args.cache_dir
        # --- quantization parameters ---
        self.awq = self.args.awq
        self.gptq = self.args.gptq

    def _add_common_args(self):
        """Register all common InfiniLM flags on ``self.parser``."""
        # --- base configuration ---
        self.parser.add_argument("--model", type=str, required=True)
        self.parser.add_argument("--device", type=str, default="cpu")
        self.parser.add_argument("--tp", "--tensor-parallel-size", type=int, default=1)
        # --- Infer backend optimization ---
        self.parser.add_argument("--attn", type=str, default="default", choices=["default", "paged-attn", "flash-attn"])
        self.parser.add_argument("--enable-graph", action="store_true")
        self.parser.add_argument("--cache-type", type=str, default="paged", choices=["paged", "static"])
        self.parser.add_argument("--enable-paged-attn", action="store_true", help="use paged cache")
        self.parser.add_argument("--paged-kv-block-size", type=int, default=256)
        self.parser.add_argument("--kv-cache-dtype", type=str, default=None, choices=["int8"], help="KV cache data type")
        self.parser.add_argument("--skip-load", action="store_true", help="skip loading model weights")
        # --- Length and infer parameters ---
        self.parser.add_argument("--batch-size", type=int, default=1)
        self.parser.add_argument("--input-len", type=int, default=10, help="input sequence length")
        self.parser.add_argument("--output-len", type=int, default=20, help="output sequence length")
        self.parser.add_argument("--max-new-tokens", type=int, default=500, help="maximum number of new tokens to generate")
        self.parser.add_argument("--top-k", type=int, default=1)
        self.parser.add_argument("--top-p", type=float, default=1.0)
        self.parser.add_argument("--temperature", type=float, default=1.0)
        # --- debug ---
        self.parser.add_argument("--warmup", action="store_true")
        self.parser.add_argument("--verbose", action="store_true")
        # "--log-evel" was a typo; keep it as the first option string so
        # the dest stays "log_evel" (backward compatible), and accept the
        # corrected "--log-level" spelling as an alias.
        self.parser.add_argument("--log-evel", "--log-level", type=str, default="INFO", help="logging level")
        # --- Evaluation parameters ---
        self.parser.add_argument("--bench", type=str, default=None, choices=["ceval", "mmlu"], help="benchmark to evaluate")
        self.parser.add_argument("--backend", type=str, default="cpp", choices=["python", "cpp", "torch", "vllm"], help="backend type")
        self.parser.add_argument("--ndev", type=int, default=1, help="number of devices for tensor parallelism")
        self.parser.add_argument("--subject", type=str, default="all", help="subject(s) to evaluate, comma-separated or 'all'")
        self.parser.add_argument("--split", type=str, default="test", choices=["test", "val", "all"], help="dataset split to use")
        self.parser.add_argument("--num-samples", type=int, default=None, help="number of samples to evaluate per subject")
        self.parser.add_argument("--output-csv", type=str, default=None, help="path to output CSV file for results")
        self.parser.add_argument("--cache-dir", type=str, default=None, help="directory for dataset cache")
        # --- Quantization parameters ---
        self.parser.add_argument("--awq", action="store_true", help="use AWQ quantization")
        self.parser.add_argument("--gptq", action="store_true", help="use GPTQ quantization")

    def _get_device_type(self, dev_str):
        """Convert a device name string to its DeviceType enum value.

        Lookup is case-insensitive; unknown names fall back to CPU.
        """
        DEVICE_TYPE_MAP = {
            "cpu": DeviceType.DEVICE_TYPE_CPU,
            "nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
            "qy": DeviceType.DEVICE_TYPE_QY,
            "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
            "ascend": DeviceType.DEVICE_TYPE_ASCEND,
            "metax": DeviceType.DEVICE_TYPE_METAX,
            "moore": DeviceType.DEVICE_TYPE_MOORE,
            "iluvatar": DeviceType.DEVICE_TYPE_ILUVATAR,
            "kunlun": DeviceType.DEVICE_TYPE_KUNLUN,
            "hygon": DeviceType.DEVICE_TYPE_HYGON,
            "ali": DeviceType.DEVICE_TYPE_ALI
        }
        return DEVICE_TYPE_MAP.get(dev_str.lower(), DeviceType.DEVICE_TYPE_CPU)

    def __repr__(self):
        """String representation of configuration"""
        return f"BaseConfig(model='{self.model}', device='{self.device_name}', tp={self.tp})"