Commit 9073229

Bihan Rana authored and committed
Add Open-R1 Example
1 parent 5a5b134 commit 9073229

5 files changed

Lines changed: 249 additions & 0 deletions

docs/examples.md

Lines changed: 10 additions & 0 deletions
@@ -74,6 +74,16 @@ hide:
         with RAGEN, verl, and Ray.
       </p>
     </a>
+    <a href="/examples/distributed-training/open-r1"
+       class="feature-cell sky">
+      <h3>
+        Open-R1
+      </h3>
+
+      <p>
+        Fine-tune to reproduce DeepSeek-R1 with separate inference and training nodes.
+      </p>
+    </a>
   </div>

docs/examples/distributed-training/open-r1/index.md

Whitespace-only changes.
examples/distributed-training/open-r1/.dstack.yml

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
type: task
name: open-r1-grpo

# Size of the cluster
nodes: 2

python: 3.12

nvcc: true

# Required environment variables
env:
  - HF_TOKEN
  - WANDB_API_KEY
  - NCCL_DEBUG=INFO
  # vLLM configuration
  - USE_VLLM=true
  - MODEL=Qwen/Qwen2.5-Coder-7B-Instruct
  # Qwen2.5-Coder-7B-Instruct has 28 attention heads; the head count must be divisible by TP and DP
  - TP=4
  - DP=2

# Commands of the task
commands:
  - uv pip install vllm==0.8.5.post1
  - uv pip install setuptools
  - uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
  - git clone https://github.com/huggingface/open-r1.git
  - cd open-r1
  - uv pip install .
  - |
    # Get the last IP from DSTACK_NODES_IPS for the vLLM node
    VLLM_HOST=$(echo $DSTACK_NODES_IPS | tr ' ' '\n' | tail -n 1)
    echo "vLLM host IP (last node): $VLLM_HOST"

    if [ "$USE_VLLM" = "true" ]; then
      if [ "$DSTACK_NODE_RANK" -eq $(($DSTACK_NODES_NUM - 1)) ]; then
        # The last node runs the vLLM server
        echo "Starting vLLM server on the last node (IP: $VLLM_HOST)"
        trl vllm-serve --model $MODEL --tensor_parallel_size $TP --data_parallel_size $DP --host 0.0.0.0
      else
        # The remaining nodes run training - adjust the node count and world size accordingly
        GPUS_PER_NODE=$(($DSTACK_GPUS_NUM / $DSTACK_NODES_NUM))
        ADJUSTED_NODES_NUM=$(($DSTACK_NODES_NUM - 1))
        ADJUSTED_GPUS_TOTAL=$(($GPUS_PER_NODE * $ADJUSTED_NODES_NUM))
        echo "Starting training with vLLM on $VLLM_HOST"
        accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
          --num_processes=$ADJUSTED_GPUS_TOTAL \
          --num_machines=$ADJUSTED_NODES_NUM \
          --machine_rank=$DSTACK_NODE_RANK \
          --main_process_ip=$DSTACK_MASTER_NODE_IP \
          --main_process_port=8008 \
          src/open_r1/grpo.py \
          --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
          --model_name_or_path $MODEL \
          --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
          --hub_model_id sjbbihan/Qwen2.5-Coder-7B-Instruct \
          --vllm_server_host=$VLLM_HOST
      fi
    else
      # Standard training mode without vLLM
      echo "Running standard training without vLLM"
      accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
        --num_processes=$DSTACK_GPUS_NUM \
        --num_machines=$DSTACK_NODES_NUM \
        --machine_rank=$DSTACK_NODE_RANK \
        --main_process_ip=$DSTACK_MASTER_NODE_IP \
        --main_process_port=8008 \
        src/open_r1/grpo.py \
        --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
        --model_name_or_path $MODEL \
        --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
        --hub_model_id sjbbihan/Qwen2.5-Coder-7B-Instruct \
        --use_vllm false
    fi

resources:
  gpu: 80GB:8
  shm_size: 128GB

volumes:
  - /checkpoints:/checkpoints
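
As an aside, here is how the last-IP selection in the block above behaves with sample values (an editorial sketch, not part of the commit; the IPs are made up):

```shell
$ DSTACK_NODES_IPS="10.0.0.1 10.0.0.2"
$ echo $DSTACK_NODES_IPS | tr ' ' '\n' | tail -n 1
10.0.0.2
```

With `nodes: 2`, the node whose rank equals `$DSTACK_NODES_NUM - 1` (rank 1) serves vLLM at this IP, while rank 0 runs training.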
examples/distributed-training/open-r1/README.md

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
# Open-R1

This example demonstrates how to use `dstack` and [open-r1](https://github.com/huggingface/open-r1) to run GRPO training across distributed nodes, dedicating one node to [vLLM](https://github.com/vllm-project/vllm) inference while using the remaining nodes for training.

??? info "Prerequisites"
    Once `dstack` is [installed](https://dstack.ai/docs/installation), go ahead and clone the repo, then run `dstack init`.

    <div class="termy">

    ```shell
    $ git clone https://github.com/dstackai/dstack
    $ cd dstack
    $ dstack init
    ```
    </div>

## Create a fleet

Before submitting distributed training runs, make sure to create a fleet with `placement` set to `cluster`; a minimal sketch is shown below.

> For more details on how to use clusters with `dstack`, check the [Clusters](https://dstack.ai/docs/guides/clusters) guide.
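
The following fleet configuration is an editorial sketch, not part of this commit: the fleet name is a placeholder, and depending on your backend or SSH setup you may need additional fields (see the Clusters guide):

```yaml
type: fleet
name: open-r1-fleet

# Two interconnected nodes, matching the task below
nodes: 2
placement: cluster

resources:
  gpu: 80GB:8
```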

## Define a configuration

Once the fleet is created, define a distributed task configuration.

<div editor-title="examples/distributed-training/open-r1/.dstack.yml">

```yaml
type: task
name: open-r1-grpo

# Size of the cluster
nodes: 2

python: 3.12

nvcc: true

# Required environment variables
env:
  - HF_TOKEN
  - HUB_MODEL_ID
  - WANDB_API_KEY
  - NCCL_DEBUG=INFO
  # vLLM configuration
  - USE_VLLM=true
  - MODEL=Qwen/Qwen2.5-Coder-7B-Instruct
  - TP=4
  - DP=2

# Commands of the task
commands:
  - uv pip install vllm==0.8.5.post1
  - uv pip install setuptools
  - uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
  - git clone https://github.com/huggingface/open-r1.git
  - cd open-r1
  - uv pip install .
  - |
    # Get the last IP from DSTACK_NODES_IPS for the vLLM node
    VLLM_HOST=$(echo $DSTACK_NODES_IPS | tr ' ' '\n' | tail -n 1)
    echo "vLLM host IP (last node): $VLLM_HOST"

    if [ "$USE_VLLM" = "true" ]; then
      if [ "$DSTACK_NODE_RANK" -eq $(($DSTACK_NODES_NUM - 1)) ]; then
        # The last node runs the vLLM server
        echo "Starting vLLM server on the last node (IP: $VLLM_HOST)"
        trl vllm-serve --model $MODEL --tensor_parallel_size $TP --data_parallel_size $DP --host 0.0.0.0
      else
        # The remaining nodes run training - adjust the node count and world size accordingly
        GPUS_PER_NODE=$(($DSTACK_GPUS_NUM / $DSTACK_NODES_NUM))
        ADJUSTED_NODES_NUM=$(($DSTACK_NODES_NUM - 1))
        ADJUSTED_GPUS_TOTAL=$(($GPUS_PER_NODE * $ADJUSTED_NODES_NUM))
        echo "Starting training with vLLM on $VLLM_HOST"
        accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
          --num_processes=$ADJUSTED_GPUS_TOTAL \
          --num_machines=$ADJUSTED_NODES_NUM \
          --machine_rank=$DSTACK_NODE_RANK \
          --main_process_ip=$DSTACK_MASTER_NODE_IP \
          --main_process_port=8008 \
          src/open_r1/grpo.py \
          --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
          --model_name_or_path $MODEL \
          --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
          --hub_model_id $HUB_MODEL_ID \
          --vllm_server_host=$VLLM_HOST
      fi
    else
      # Standard training mode without vLLM
      echo "Running standard training without vLLM"
      accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
        --num_processes=$DSTACK_GPUS_NUM \
        --num_machines=$DSTACK_NODES_NUM \
        --machine_rank=$DSTACK_NODE_RANK \
        --main_process_ip=$DSTACK_MASTER_NODE_IP \
        --main_process_port=8008 \
        src/open_r1/grpo.py \
        --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
        --model_name_or_path $MODEL \
        --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
        --hub_model_id $HUB_MODEL_ID \
        --use_vllm false
    fi

resources:
  gpu: 80GB:8
  shm_size: 128GB

volumes:
  - /checkpoints:/checkpoints
```
</div>

!!! info "NOTE:"
    1. With `nodes: N` and `USE_VLLM=true`, N-1 nodes are used for distributed training, while the last node serves vLLM inference. For example, with `nodes: 2` and eight GPUs per node, one node (8 GPUs) trains while the other serves vLLM.
    2. With `USE_VLLM=false`, training runs across all nodes.
    3. The model's number of attention heads should be divisible by the `TP` and `DP` values.
    4. We pin `flash-attn` to `2.7.4.post1` for compatibility with `vllm==0.8.5.post1`, since the latest vLLM causes TRL's GRPO to fail. See [issue #3608](https://github.com/huggingface/trl/issues/3608) for details.
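
To double-check item 3 before launching, you can read the head count straight from the model config (an editorial aside, not part of the commit; assumes `transformers` is installed locally):

```shell
$ python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct').num_attention_heads)"
28
```

28 is divisible by both `TP=4` and `DP=2`, so the values in the configuration above are valid.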

### Apply the configuration

To run a configuration, use the [`dstack apply`](https://dstack.ai/docs/reference/cli/dstack/apply.md) command.

<div class="termy">

```shell
$ HF_TOKEN=...
$ HUB_MODEL_ID=...
$ WANDB_API_KEY=...
$ dstack apply -f examples/distributed-training/open-r1/.dstack.yml

 #  BACKEND       RESOURCES                       INSTANCE TYPE  PRICE
 1  ssh (remote)  cpu=208 mem=1772GB H100:80GB:8  instance       $0  idle
 2  ssh (remote)  cpu=208 mem=1772GB H100:80GB:8  instance       $0  idle

Submit the run open-r1-grpo? [y/n]: y

Provisioning...
---> 100%
```
</div>
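
Once the run is submitted, you can follow its output with the dstack CLI (a general usage note rather than part of this example; check `dstack --help` if your version differs):

```shell
$ dstack logs open-r1-grpo
```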

## Source code

The source code of this example can be found in
[`examples/distributed-training/open-r1` :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack/blob/master/examples/distributed-training/open-r1){:target="_blank"}.

!!! info "What's next?"
    1. Read the [Clusters](https://dstack.ai/docs/guides/clusters) guide.
    2. Check [dev environments](https://dstack.ai/docs/concepts/dev-environments), [tasks](https://dstack.ai/docs/concepts/tasks),
       [services](https://dstack.ai/docs/concepts/services), and [fleets](https://dstack.ai/docs/concepts/fleets).

mkdocs.yml

Lines changed: 1 addition & 0 deletions

@@ -268,6 +268,7 @@ nav:
           - TRL: examples/distributed-training/trl/index.md
           - Axolotl: examples/distributed-training/axolotl/index.md
           - Ray+RAGEN: examples/distributed-training/ray-ragen/index.md
+          - Open-R1: examples/distributed-training/open-r1/index.md
       - Clusters:
           - NCCL tests: examples/clusters/nccl-tests/index.md
           - RCCL tests: examples/clusters/rccl-tests/index.md
