BlinkDL/Albatross

Albatross: efficient RWKV inference engine

UPDATE: faster3_2605 can reach 17000+ tps prefill (B1T1024), 15000+ tps decode (B1024T1), and 21000+ tps batch prefill (B32T32) on a single 5090.

UPDATE: faster3a_2605 (currently the fastest) is up to 40% faster than faster3_2605 for small B/T. For better performance, tune linear_orig_layout for your GPU.

RESULT B=1 T=1 iters=3 p10_ms=6.9425 p50_ms=6.9427 p90_ms=7.1073 tok_s_p50=144.04
RESULT B=1 T=2 iters=3 p10_ms=7.2224 p50_ms=7.2231 p90_ms=7.3045 tok_s_p50=276.89
RESULT B=1 T=4 iters=3 p10_ms=7.8479 p50_ms=7.8638 p90_ms=8.0480 tok_s_p50=508.66
RESULT B=1 T=8 iters=3 p10_ms=8.9945 p50_ms=8.9973 p90_ms=9.0790 tok_s_p50=889.15
RESULT B=1 T=16 iters=3 p10_ms=9.2388 p50_ms=9.2642 p90_ms=9.3825 tok_s_p50=1727.09
RESULT B=1 T=32 iters=3 p10_ms=11.1926 p50_ms=11.1940 p90_ms=11.4933 tok_s_p50=2858.66
RESULT B=1 T=64 iters=3 p10_ms=11.6656 p50_ms=11.6670 p90_ms=11.9468 tok_s_p50=5485.54
RESULT B=1 T=128 iters=3 p10_ms=13.4997 p50_ms=13.5012 p90_ms=13.6163 tok_s_p50=9480.67
RESULT B=1 T=256 iters=3 p10_ms=18.2705 p50_ms=18.2778 p90_ms=18.3811 tok_s_p50=14006.07
RESULT B=2 T=1 iters=3 p10_ms=7.2577 p50_ms=7.2684 p90_ms=7.3323 tok_s_p50=275.16
RESULT B=4 T=1 iters=3 p10_ms=7.9306 p50_ms=7.9442 p90_ms=8.0348 tok_s_p50=503.51
RESULT B=8 T=1 iters=3 p10_ms=8.7188 p50_ms=8.7593 p90_ms=8.9117 tok_s_p50=913.32
RESULT B=16 T=1 iters=3 p10_ms=9.3525 p50_ms=9.3743 p90_ms=9.6280 tok_s_p50=1706.79
RESULT B=32 T=1 iters=3 p10_ms=11.2196 p50_ms=11.2238 p90_ms=11.4337 tok_s_p50=2851.07
RESULT B=64 T=1 iters=3 p10_ms=11.6686 p50_ms=11.6814 p90_ms=11.8833 tok_s_p50=5478.79
RESULT B=128 T=1 iters=3 p10_ms=13.6054 p50_ms=13.6102 p90_ms=13.7000 tok_s_p50=9404.68
RESULT B=256 T=1 iters=3 p10_ms=19.4996 p50_ms=19.5026 p90_ms=19.6272 tok_s_p50=13126.46
RESULT B=2 T=2 iters=3 p10_ms=7.8615 p50_ms=7.8702 p90_ms=7.9935 tok_s_p50=508.25
RESULT B=4 T=4 iters=3 p10_ms=9.1181 p50_ms=9.1330 p90_ms=9.2556 tok_s_p50=1751.89
RESULT B=8 T=8 iters=3 p10_ms=11.0723 p50_ms=11.0758 p90_ms=11.2748 tok_s_p50=5778.38
RESULT B=16 T=16 iters=3 p10_ms=14.8019 p50_ms=14.8049 p90_ms=14.8631 tok_s_p50=17291.61
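
The tok_s_p50 column in the RESULT lines above looks like B x T tokens divided by the p50 latency. That reading is inferred from the numbers themselves, not taken from the benchmark source, so treat the sketch below as a sanity check rather than a description of the actual script. The helper names (parse_result, tokens_per_second) are hypothetical.

```python
import re

# Matches one RESULT line in the format printed above.
RESULT_RE = re.compile(
    r"RESULT B=(\d+) T=(\d+) iters=(\d+) "
    r"p10_ms=([\d.]+) p50_ms=([\d.]+) p90_ms=([\d.]+) tok_s_p50=([\d.]+)"
)


def parse_result(line):
    """Parse one RESULT line into a dict of ints and floats."""
    m = RESULT_RE.match(line.strip())
    if m is None:
        raise ValueError(f"not a RESULT line: {line!r}")
    return {
        "B": int(m.group(1)),
        "T": int(m.group(2)),
        "iters": int(m.group(3)),
        "p10_ms": float(m.group(4)),
        "p50_ms": float(m.group(5)),
        "p90_ms": float(m.group(6)),
        "tok_s_p50": float(m.group(7)),
    }


def tokens_per_second(batch, tokens, latency_ms):
    """Assumed formula: B * T tokens processed per forward pass of latency_ms."""
    return batch * tokens / (latency_ms / 1000.0)


if __name__ == "__main__":
    line = ("RESULT B=16 T=16 iters=3 p10_ms=14.8019 p50_ms=14.8049 "
            "p90_ms=14.8631 tok_s_p50=17291.61")
    r = parse_result(line)
    recomputed = tokens_per_second(r["B"], r["T"], r["p50_ms"])
    print(f"reported={r['tok_s_p50']:.2f} recomputed={recomputed:.2f}")
```

Recomputing a few rows this way is a quick check that a log was produced with the same B/T settings it claims.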

UPDATE: faster4_2605_cpp is a standalone C++ inference implementation (no libtorch, no Python). It is faster for some BnTn and slower for others. For better performance, tune linear_orig_layout_launch for your GPU.


Demo: enter faster2_251201 (slower than v3a and v4) and run benchmark.py (fastest decode), demo3.py (fastest batch decode), and demo4.py (writes 120 webpages in parallel).

Note: demo3.py has an efficient standalone Python GUI, so you can run it directly on your GPU computer.

For demo2.py, SSH into the GPU computer and run it inside the SSH session, so that the GPU is not slowed down by terminal rendering.


Old Readme

Please check this first: https://github.com/BlinkDL/Albatross/blob/main/benchmark.py

Faster fwd & bwd CUDA kernels: https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused

Full backend: https://github.com/RWKV-Vibe/rwkv_lightning

Fast sampling: https://github.com/Triang-jyed-driung/Rapid-Sampling

Result @ 251201

145+ token/s RWKV-7 7.2B fp16 bsz1 @ RTX5090

11289 token/s RWKV-7 7.2B fp16 bsz1 prefill @ RTX5090

Code: https://github.com/Triang-jyed-driung/Albatross/tree/fp16

(enable torch.compile in https://github.com/Triang-jyed-driung/Albatross/blob/fp16/reference/rwkv7.py)

Result @ 251103

10250+ token/s RWKV-7 7.2B fp16 bsz960 @ RTX5090

9650+ token/s RWKV-7 7.2B fp16 bsz320 @ RTX5090

123+ token/s RWKV-7 7.2B fp16 bsz1 @ RTX5090 with CUDAGraph and sparse FFN (lossless)

Code: https://github.com/BlinkDL/Albatross/tree/main/faster_251101

Result @ 251007

1.3x faster 7B decoding and 5x faster 0.1B decoding, with CUDAGraph.

Result @ 250909

Now with batch inference. 7B fp16 bsz 320 = 5848 token/s decoding on 5090 (constant speed and VRAM, because it's an RNN). I think 10000 token/s is achievable (full fp16 precision).
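
To illustrate the "constant speed and VRAM" remark: a recurrent model carries a fixed-size state from token to token, whereas an attention-style KV cache grows with every token. The toy below is only a sketch of that contrast. It uses a made-up scalar-decay recurrence, not RWKV-7's actual update rule, and all names and parameters (state_dim, decay) are hypothetical.

```python
def rnn_state_sizes(num_tokens, state_dim=4, decay=0.5):
    """Toy recurrence: the state is overwritten in place, so its size is constant."""
    state = [0.0] * state_dim
    sizes = []
    for t in range(num_tokens):
        x = [float(t + 1)] * state_dim                    # stand-in for a token's features
        state = [decay * s + xi for s, xi in zip(state, x)]
        sizes.append(len(state))                          # floats kept after this token
    return sizes


def kv_cache_sizes(num_tokens, state_dim=4):
    """Toy KV cache: one entry is appended per token, so memory grows linearly."""
    cache = []
    sizes = []
    for t in range(num_tokens):
        cache.append([float(t + 1)] * state_dim)
        sizes.append(len(cache) * state_dim)              # floats kept after this token
    return sizes


if __name__ == "__main__":
    print("recurrent state:", rnn_state_sizes(5))
    print("kv cache:       ", kv_cache_sizes(5))
```

Because the per-token work and the stored state do not depend on how many tokens came before, decode throughput and memory stay flat as the context grows.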

Result @ 250904

Baseline performance for RWKV-7 7.2B bsz=1 @ RTX5090, simply abysmal lol

Let me know if you can find simple methods (such as tuning torch.compile etc.) to improve these a bit

Token/s = 75.1 (forward), 73.76 (full) || Bandwidth = 1041.2 GB/s || 3.722s

CTX_LEN 512 : avg loss 1.6548 || prefill 9163 token/s = 127.03 TFLOPS
CTX_LEN 1024 : avg loss 1.5689 || prefill 9742 token/s = 135.06 TFLOPS
CTX_LEN 2048 : avg loss 1.5141 || prefill 10081 token/s = 139.76 TFLOPS
CTX_LEN 4096 : avg loss 1.4824 || prefill 10427 token/s = 144.55 TFLOPS
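
The figures above can be cross-checked against each other. The sketch below divides the reported TFLOPS by the prefill token/s to get FLOPs per token, and the reported bandwidth by the decode token/s to get bytes read per token. It assumes 1 GB = 1e9 bytes, roughly 2 FLOPs per weight per token, and 2 bytes per fp16 weight; those conventions are assumptions, not something the benchmark states, and the function names are hypothetical.

```python
# Numbers copied from the 250904 baseline above.
DECODE_TOK_S = 75.1          # forward-only decode speed, token/s
DECODE_BANDWIDTH_GB_S = 1041.2
PREFILL = {                  # CTX_LEN -> (token/s, TFLOPS)
    512: (9163, 127.03),
    1024: (9742, 135.06),
    2048: (10081, 139.76),
    4096: (10427, 144.55),
}


def flops_per_token(tflops, tokens_per_s):
    """FLOPs spent per token, given sustained TFLOPS and throughput."""
    return tflops * 1e12 / tokens_per_s


def bytes_per_token(bandwidth_gb_s, tokens_per_s):
    """Bytes read per decoded token, assuming 1 GB = 1e9 bytes."""
    return bandwidth_gb_s * 1e9 / tokens_per_s


if __name__ == "__main__":
    for ctx_len, (tok_s, tflops) in PREFILL.items():
        gflop = flops_per_token(tflops, tok_s) / 1e9
        # Assumption: 2 FLOPs (multiply + add) per weight per token.
        print(f"CTX_LEN {ctx_len}: {gflop:.2f} GFLOP/token, "
              f"implied weights {gflop / 2:.2f}B")
    gbytes = bytes_per_token(DECODE_BANDWIDTH_GB_S, DECODE_TOK_S) / 1e9
    # Assumption: 2 bytes per fp16 weight, each weight read once per token.
    print(f"decode: {gbytes:.2f} GB/token, implied weights {gbytes / 2:.2f}B")
```

Under these assumptions both routes point to roughly the same weight count (a little under 7B), which suggests the TFLOPS and bandwidth figures were derived from one shared parameter count.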

About

RWKV-7 7.2B fp16 15000+ tps decoding @ single 5090
