# Fused Flash Decode Attention

This example implements a distributed Flash Decode kernel designed to accelerate LLM inference. Part of the code is adapted from Triton-distributed.

The implementation fuses communication and computation, reducing collective kernel launch latencies and the associated waits.

The core layer implementation is in `examples/13_flash_decode/flash_decode_fused_layer.py`, while the Triton fused kernels are defined in `examples/13_flash_decode/decode_kernels.py`.

We perform comparisons against the RCCL baseline.
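To illustrate the core idea behind distributed flash decode, here is a minimal NumPy sketch of splitting the KV cache across shards and merging the per-shard partial results with a numerically stable log-sum-exp combine. The function names and shapes are illustrative only and do not reflect the repo's actual kernels or API.

```python
import numpy as np

def reference_attention(q, k, v):
    # Full softmax attention for a single query vector (the baseline).
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v

def partial_attention(q, k, v):
    # Per-shard partial result: unnormalized output plus softmax statistics
    # (row sum and row max), which are enough to merge shards later.
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max()
    p = np.exp(s - m)
    return p @ v, p.sum(), m

def combine(partials):
    # Merge shard partials: rescale each shard by exp(m_i - m*) so all
    # terms share the global max m*, then normalize once at the end.
    m_star = max(m for _, _, m in partials)
    num = sum(np.exp(m - m_star) * o for o, _, m in partials)
    den = sum(np.exp(m - m_star) * l for _, l, m in partials)
    return num / den

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((128, 64))
v = rng.standard_normal((128, 64))

# Split the KV cache into 4 shards, compute partials, then combine.
shards = [partial_attention(q, k[i:i + 32], v[i:i + 32]) for i in range(0, 128, 32)]
out = combine(shards)
ref = reference_attention(q, k, v)
assert np.allclose(out, ref)
```

In the distributed setting each rank holds one KV shard, and the fused kernel overlaps this combine step with the communication of the partial outputs and statistics, rather than launching separate collective and compute kernels.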


## Usage

### Simple Example

To do a quick test run of the code:

```bash
torchrun --nproc_per_node=<num_gpus> --standalone examples/13_flash_decode/example.py
```

Pass `--validate` to verify the output against a PyTorch reference:

```bash
torchrun --nproc_per_node=<num_gpus> --standalone examples/13_flash_decode/example.py --validate
```

## Validation

These scripts use `pytest` to verify the numerical correctness of each implementation against a standard PyTorch reference.

### Iris

```bash
python tests/run_tests_distributed.py tests/examples/test_flash_decode.py --num_ranks 8
```

### RCCL

```bash
python tests/run_tests_distributed.py examples/benchmark/reference/flash_decode_rccl/validate_flash_decode_rccl.py --num_ranks 8
```

## Benchmarking

These scripts run a sweep of configurations and save performance results as `.json` files in the `results/` directory.

### Iris

```bash
python benchmark/examples/benchmark_flash_decode.py --num_ranks 8
```

### RCCL

```bash
torchrun --nproc_per_node=8 examples/benchmark/reference/flash_decode_rccl/benchmark_flash_decode_rccl.py
```