Skip to content

Commit 5819a4a

Browse files
authored
Merge pull request #326 from karmeh01/add-compile-option
Revert "Removes compile option from genai PyTorch examples"
2 parents 7192a65 + b472089 commit 5819a4a

3 files changed

Lines changed: 18 additions & 2 deletions

File tree

ML-Frameworks/pytorch-aarch64/examples/README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ The script [torchchat_llm_text_gen.py](torchchat_llm_text_gen.py) demonstrates h
201201
To run inference using torchchat, call:
202202

203203
```
204-
LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python torchchat_llm_text_gen.py
204+
LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python torchchat_llm_text_gen.py --compile
205205
```
206206

207207
#### Command-Line Options
@@ -212,6 +212,9 @@ LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPE
212212
`--max-new-tokens`
213213
Description: Max new tokens to generate.
214214

215+
`--compile`
216+
Description: Whether to compile the model (default: `False`).
217+
215218
`--model`
216219
Description: Model alias. (Default: `"llama2"` )
217220

@@ -224,7 +227,7 @@ The script [transformers_llm_text_gen.py](transformers_llm_text_gen.py) demonstr
224227
To run inference using transformers, call:
225228

226229
```
227-
LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python transformers_llm_text_gen.py
230+
LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPER=1 TORCHINDUCTOR_FREEZING=1 OMP_NUM_THREADS=16 python transformers_llm_text_gen.py --compile
228231
```
229232

230233
#### Command-Line Options
@@ -235,6 +238,9 @@ LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4 TORCHINDUCTOR_CPP_WRAPPE
235238
`--max-new-tokens`
236239
Description: Max new tokens to generate.
237240

241+
`--compile`
242+
Description: Whether to compile the model (default: `False`).
243+
238244
`--model`
239245
Description: Local Path to model repo or huggingface model id. (Default: `"meta-llama/Llama-2-7b-hf"` )
240246

ML-Frameworks/pytorch-aarch64/examples/torchchat_llm_text_gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def main(args):
3131
"python3", torchchat_path, "generate", args.model,
3232
"--quantize", str(args.quant_config),
3333
"--prompt", prompt,
34+
"--compile" if args.compile else "",
35+
"--compile-prefill" if args.compile else "",
3436
"--max-autotune", "--max-new-tokens", str(args.max_new_tokens)
3537
]
3638
command = [arg for arg in command if arg]
@@ -45,6 +47,8 @@ def main(args):
4547
help='Path to json file for quantization config')
4648
parser.add_argument('--max-new-tokens', type=int,
4749
default=64, help='New tokens to generate at decode.')
50+
parser.add_argument('--compile', action='store_true',
51+
help='Whether to compile the model.')
4852
parser.add_argument('--model', type=str, default="llama2",
4953
help='Torchchat supported model alias')
5054
parser.add_argument('--prompt', type=str, default="In a distant world where magic and technology coexist, "

ML-Frameworks/pytorch-aarch64/examples/transformers_llm_text_gen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def get_quantized_model(args):
130130
print("Quantizing model to 4 bit ..")
131131
quantize_model(model, "cpu", args.quant_config)
132132
model = model.eval()
133+
if args.compile:
134+
model.generation_config.cache_implementation = "static"
135+
model.forward = torch.compile(
136+
model.forward, backend='inductor', dynamic=True, fullgraph=True)
133137
return model, tokenizer, config
134138

135139

@@ -193,6 +197,8 @@ def main(args):
193197
"gen_ai_utils/quant_configs/aarch64_cpu_channelwise.json", help='Path to json file for quantization config')
194198
parser.add_argument('--max-new-tokens', type=int,
195199
default=64, help='New tokens to generate at decode.')
200+
parser.add_argument('--compile', action='store_true',
201+
help='Whether to compile the model.')
196202
parser.add_argument('--model', type=Path, default=Path("meta-llama/Llama-2-7b-hf"),
197203
help='Hugging Face model ID or Cloned model repository with model files')
198204
parser.add_argument('--prompt', type=str, default="In a distant world where magic and technology coexist, "

0 commit comments

Comments
 (0)