Skip to content

Commit cd2d5be

Browse files
authored
ensure efficient kascade is only used with fp16 datatype
1 parent 96a3e49 commit cd2d5be

3 files changed

Lines changed: 7 additions & 4 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<p align="center">
22
<picture>
3-
<img alt="Kascade" src="assets/logo_kascade.png" height="20%" width="20%">
3+
<img alt="Kascade" src="assets/logo_kascade.png" height="25%" width="25%">
44
</picture>
55
</p>
66

@@ -72,6 +72,7 @@ python scripts/eval_script.py --model_name meta-llama/Meta-Llama-3.1-8B-Instruct
7272
```
7373

7474
**NOTE:** Currently for `efficient_kascade`, which uses our efficient kernels, only `tile_size` 32 is supported.
75+
**NOTE:** Currently, `efficient_kascade` supports only fp16. By default, models are loaded in the dtype specified in their model config. To run `efficient_kascade`, go to line 17 in `src/model_utils.py` and set `torch_dtype=torch.float16` when loading the model.
7576

7677
**NOTE**: To use multiple GPUs you can run `scripts/eval_script.py` with `accelerate launch` and take advantage of DDP for faster processing of queries. If you run into errors, fallback to single gpu runs.
7778

scripts/eval_script.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from datasets import load_dataset
1313
from transformers import set_seed
1414
from transformers.utils import is_flash_attn_2_available, is_flash_attn_3_available
15+
import torch
1516

1617
def main():
1718
# Parse the arguments
@@ -93,6 +94,9 @@ def main():
9394

9495
# Loop: models -> strategies -> subsets
9596
for strategy_name in args.strategies:
97+
if strategy_name == "efficient_kascade" and model.config.dtype != torch.float16:
98+
raise ValueError("Efficient Kascade strategy requires model to be in float16 precision. Please go to line 17 in src/model_utils.py and change torch_dtype=torch.float16 when loading the model for running with efficient_kascade.")
99+
96100
set_seed(args.seed) # Ensure reproducibility per run
97101

98102
# Create strategy

src/model_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_tokenizer_and_model(model_name, attn_implementation, device):
1414
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
1515

1616
model_kwargs = {
17-
"torch_dtype": torch.float16,
17+
"torch_dtype": "auto",
1818
"attn_implementation": attn_implementation,
1919
"cache_dir": "/dev/shm",
2020
"pretrained_model_name_or_path": model_name,
@@ -33,8 +33,6 @@ def get_tokenizer_and_model(model_name, attn_implementation, device):
3333
single_gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
3434
if model_size_gb * 2 > single_gpu_mem_gb: # float16 is 2 bytes per parameter
3535
model_kwargs["device_map"] = "auto"
36-
if model_size_gb > 8:
37-
model_kwargs["torch_dtype"] = "auto"
3836
model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
3937

4038
model.eval()

0 commit comments

Comments
 (0)