Commit 5819a4a (2 parents: 7192a65 + b472089)
3 files changed
ML-Frameworks/pytorch-aarch64/examples/README.md
@@ -201,7 +201,7 @@ The script [torchchat_llm_text_gen.py](torchchat_llm_text_gen.py) demonstrates h
 To run inference using torchchat, call:
 
 ```
-LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python torchchat_llm_text_gen.py
+LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python torchchat_llm_text_gen.py --compile
 ```
 
 #### Command-Line Options
@@ -212,6 +212,9 @@ LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPE
 `--max-new-tokens`
  Description: Max new tokens to generate.
 
+`--compile`
+ Description: Whether to compile the model (default: `False`).
+
 `--model`
  Description: Model alias. (Default: `"llama2"` )
 
@@ -224,7 +227,7 @@ The script [transformers_llm_text_gen.py](transformers_llm_text_gen.py) demonstr
 
 
 
-LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python transformers_llm_text_gen.py
+LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python transformers_llm_text_gen.py --compile
 
 
 
@@ -235,6 +238,9 @@ LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPE
 
 
 
+`--compile`
+ Description: Whether to compile the model (default: `False`).
+
 `--model`
  Description: Local path to model repo or Hugging Face model ID. (Default: `"meta-llama/Llama-2-7b-hf"` )
 
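For context, here is a minimal sketch (not part of the commit) of launching the same example from Python with the environment the README commands set inline on the shell line. The tcmalloc path is the Ubuntu aarch64 location used above; paths and thread count are assumptions to adjust for your machine.

```
# Not from the commit: a sketch of running the example with the README's
# environment variables set programmatically. Paths and thread count are
# assumptions; adjust for your machine.
import os
import subprocess

env = dict(os.environ,
           LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4",
           TORCHINDUCTOR_CPP_WRAPPER="1",
           TORCHINDUCTOR_FREEZING="1",
           OMP_NUM_THREADS="16")

# LD_PRELOAD only takes effect at process start, so the interpreter that
# runs the example must be launched with it already in its environment.
subprocess.run(["python", "torchchat_llm_text_gen.py", "--compile"],
               env=env, check=True)
```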
ML-Frameworks/pytorch-aarch64/examples/torchchat_llm_text_gen.py
@@ -31,6 +31,8 @@ def main(args):
         "python3", torchchat_path, "generate", args.model,
         "--quantize", str(args.quant_config),
         "--prompt", prompt,
+        "--compile" if args.compile else "",
+        "--compile-prefill" if args.compile else "",
         "--max-autotune", "--max-new-tokens", str(args.max_new_tokens)
     ]
     command = [arg for arg in command if arg]
@@ -45,6 +47,8 @@ def main(args):
                         help='Path to json file for quantization config')
     parser.add_argument('--max-new-tokens', type=int,
                         default=64, help='New tokens to generate at decode.')
+    parser.add_argument('--compile', action='store_true',
+                        help='Whether to compile the model.')
     parser.add_argument('--model', type=str, default="llama2",
                         help='Torchchat supported model alias')
     parser.add_argument('--prompt', type=str, default="In a distant world where magic and technology coexist, "
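The command construction above relies on a small pattern: disabled flags become empty strings, and a final comprehension filters them out, so the list literal needs no branching. A self-contained sketch of that pattern, where the script name and flags are hypothetical stand-ins:

```
# Hypothetical stand-in names; illustrates the empty-string-then-filter
# pattern used by torchchat_llm_text_gen.py above.
def build_command(compile_model):
    command = [
        "python3", "generate.py",
        "--compile" if compile_model else "",
        "--compile-prefill" if compile_model else "",
        "--max-new-tokens", "64",
    ]
    # Drop the empty placeholders left behind by disabled flags.
    return [arg for arg in command if arg]

print(build_command(False))  # ['python3', 'generate.py', '--max-new-tokens', '64']
print(build_command(True))   # additionally includes the two compile flags
```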
ML-Frameworks/pytorch-aarch64/examples/transformers_llm_text_gen.py
@@ -130,6 +130,10 @@ def get_quantized_model(args):
     print("Quantizing model to 4 bit ..")
     quantize_model(model, "cpu", args.quant_config)
     model = model.eval()
+    if args.compile:
+        model.generation_config.cache_implementation = "static"
+        model.forward = torch.compile(
+            model.forward, backend='inductor', dynamic=True, fullgraph=True)
     return model, tokenizer, config
 
 
@@ -193,6 +197,8 @@ def main(args):
                         "gen_ai_utils/quant_configs/aarch64_cpu_channelwise.json", help='Path to json file for quantization config')
 
 
+    parser.add_argument('--compile', action='store_true',
+                        help='Whether to compile the model.')
     parser.add_argument('--model', type=Path, default=Path("meta-llama/Llama-2-7b-hf"),
                         help='Hugging Face model ID or Cloned model repository with model files')
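To show what the new compile path in get_quantized_model amounts to, here is a minimal standalone sketch. It assumes a small Llama-architecture checkpoint (TinyLlama) as a stand-in for the example's gated meta-llama/Llama-2-7b-hf default, and it skips the quantization step:

```
# Sketch only: mirrors the static-cache + torch.compile combination added
# above, on an assumed small stand-in checkpoint and without quantization.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # stand-in, not the commit's default
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

# A static KV cache keeps tensor shapes fixed across decode steps, so
# inductor can reuse one compiled graph instead of recompiling each step.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(
    model.forward, backend='inductor', dynamic=True, fullgraph=True)

inputs = tokenizer("In a distant world where magic and technology coexist,",
                   return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

The first generate call pays the one-time compilation cost; subsequent calls with the same shapes reuse the compiled graph, which is where the speedup comes from.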