# Open-R1

This example demonstrates how to use `dstack` and [open-r1](https://github.com/huggingface/open-r1) to run GRPO training across distributed nodes, dedicating one node to [vLLM](https://github.com/vllm-project/vllm) inference while using the remaining nodes for training.

??? info "Prerequisites"
    Once `dstack` is [installed](https://dstack.ai/docs/installation), go ahead and clone the repo, and run `dstack init`.

    <div class="termy">

    ```shell
    $ git clone https://github.com/dstackai/dstack
    $ cd dstack
    $ dstack init
    ```
    </div>
| 16 | + |
## Create a fleet

Before submitting distributed training runs, make sure to create a fleet with `placement` set to `cluster`.

> For more details on how to use clusters with `dstack`, check the [Clusters](https://dstack.ai/docs/guides/clusters) guide.

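As a minimal sketch, an SSH fleet configuration with cluster placement might look like the following. The fleet name, user, identity file, and host IPs are placeholders; adjust them to your own machines:

```yaml
type: fleet
name: my-fleet

# Ensure instances are interconnected so nodes can communicate during training
placement: cluster

ssh_config:
  user: ubuntu
  identity_file: ~/.ssh/id_rsa
  hosts:
    - 3.255.177.51
    - 3.255.177.52
```

Apply it with `dstack apply -f` before submitting the task below.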
## Define a configuration

Once the fleet is created, define a distributed task configuration.
| 26 | + |
<div editor-title="examples/distributed-training/open-r1/.dstack.yml">

```yaml
type: task
name: open-r1-grpo

# Size of the cluster
nodes: 2

python: 3.12

nvcc: true

# Required environment variables
env:
  - HF_TOKEN
  - HUB_MODEL_ID
  - WANDB_API_KEY
  - NCCL_DEBUG=INFO
  # vLLM configuration
  - USE_VLLM=true
  - MODEL=Qwen/Qwen2.5-Coder-7B-Instruct
  - TP=4
  - DP=2

# Commands of the task
commands:
  - uv pip install vllm==0.8.5.post1
  - uv pip install setuptools
  - uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
  - git clone https://github.com/huggingface/open-r1.git
  - cd open-r1
  - uv pip install .
  - |
    # Get the last IP from DSTACK_NODES_IPS for the vLLM node
    VLLM_HOST=$(echo $DSTACK_NODES_IPS | tr ' ' '\n' | tail -n 1)
    echo "vLLM host IP (last node): $VLLM_HOST"

    if [ "$USE_VLLM" = "true" ]; then
      if [ "$DSTACK_NODE_RANK" -eq $(($DSTACK_NODES_NUM - 1)) ]; then
        # The last node runs the vLLM server
        echo "Starting vLLM server on last node (IP: $VLLM_HOST)"
        trl vllm-serve --model $MODEL --tensor_parallel_size $TP --data_parallel_size $DP --host 0.0.0.0
      else
        # Training node - adjust world size and node count for training
        GPUS_PER_NODE=$(($DSTACK_GPUS_NUM / $DSTACK_NODES_NUM))
        ADJUSTED_NODES_NUM=$(($DSTACK_NODES_NUM - 1))
        ADJUSTED_GPUS_TOTAL=$(($GPUS_PER_NODE * $ADJUSTED_NODES_NUM))
        # The other nodes run training
        echo "Starting training with vLLM on $VLLM_HOST"
        accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
          --num_processes=$ADJUSTED_GPUS_TOTAL \
          --num_machines=$ADJUSTED_NODES_NUM \
          --machine_rank=$DSTACK_NODE_RANK \
          --main_process_ip=$DSTACK_MASTER_NODE_IP \
          --main_process_port=8008 \
          src/open_r1/grpo.py \
          --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
          --model_name_or_path $MODEL \
          --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
          --hub_model_id $HUB_MODEL_ID \
          --vllm_server_host=$VLLM_HOST
      fi
    else
      # Standard training mode without vLLM
      echo "Running standard training without vLLM"
      accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
        --num_processes=$DSTACK_GPUS_NUM \
        --num_machines=$DSTACK_NODES_NUM \
        --machine_rank=$DSTACK_NODE_RANK \
        --main_process_ip=$DSTACK_MASTER_NODE_IP \
        --main_process_port=8008 \
        src/open_r1/grpo.py \
        --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml \
        --model_name_or_path $MODEL \
        --output_dir /checkpoints/Qwen2.5-Coder-7B-Instruct-GRPO \
        --hub_model_id $HUB_MODEL_ID \
        --use_vllm false
    fi

resources:
  gpu: 80GB:8
  shm_size: 128GB

volumes:
  - /checkpoints:/checkpoints
```
</div>

!!! info "NOTE:"
    1. When `nodes: N` and `USE_VLLM=true`, `N-1` nodes are used for distributed training, while the last node is used for vLLM inference.
    2. When `USE_VLLM=false`, training happens across all nodes.
    3. The number of attention heads of the model should be divisible by the `TP` and `DP` values.
    4. We pin `flash-attn` to `2.7.4.post1` for compatibility with `vllm==0.8.5.post1`, as the latest vLLM causes TRL's GRPO to fail. See [issue #3608](https://github.com/huggingface/trl/issues/3608) for details.
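The node-selection arithmetic in the task can be checked in isolation. The sketch below substitutes hypothetical values for the `DSTACK_*` variables that `dstack` injects at runtime (2 nodes with 8 GPUs each), showing which node is dedicated to vLLM and how many processes `accelerate` launches on the remaining nodes:

```shell
# Hypothetical stand-ins for variables dstack provides at runtime
DSTACK_NODES_IPS="10.0.0.1 10.0.0.2"
DSTACK_NODES_NUM=2
DSTACK_GPUS_NUM=16

# The last IP in the space-separated list is dedicated to the vLLM server
VLLM_HOST=$(echo $DSTACK_NODES_IPS | tr ' ' '\n' | tail -n 1)
echo "vLLM host: $VLLM_HOST"                     # → vLLM host: 10.0.0.2

# The remaining nodes train; recompute the counts accelerate needs
GPUS_PER_NODE=$(($DSTACK_GPUS_NUM / $DSTACK_NODES_NUM))
ADJUSTED_NODES_NUM=$(($DSTACK_NODES_NUM - 1))
ADJUSTED_GPUS_TOTAL=$(($GPUS_PER_NODE * $ADJUSTED_NODES_NUM))
echo "training nodes: $ADJUSTED_NODES_NUM"       # → training nodes: 1
echo "training processes: $ADJUSTED_GPUS_TOTAL"  # → training processes: 8
```

With more nodes the same arithmetic holds: for `nodes: 4` with 8 GPUs each, three nodes train with 24 processes total and one serves vLLM.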

### Apply the configuration

To run a configuration, use the [`dstack apply`](https://dstack.ai/docs/reference/cli/dstack/apply.md) command.

<div class="termy">

```shell
$ HF_TOKEN=...
$ HUB_MODEL_ID=...
$ WANDB_API_KEY=...
$ dstack apply -f examples/distributed-training/open-r1/.dstack.yml

 #  BACKEND       RESOURCES                       INSTANCE TYPE  PRICE
 1  ssh (remote)  cpu=208 mem=1772GB H100:80GB:8  instance       $0  idle
 2  ssh (remote)  cpu=208 mem=1772GB H100:80GB:8  instance       $0  idle

Submit the run open-r1-grpo? [y/n]: y

Provisioning...
---> 100%
```
</div>

## Source code

The source code of this example can be found in
[`examples/distributed-training/open-r1` :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack/blob/master/examples/distributed-training/open-r1){:target="_blank"}.

!!! info "What's next?"
    1. Read the [Clusters](https://dstack.ai/docs/guides/clusters) guide
    2. Check [dev environments](https://dstack.ai/docs/concepts/dev-environments), [tasks](https://dstack.ai/docs/concepts/tasks),
       [services](https://dstack.ai/docs/concepts/services), and [fleets](https://dstack.ai/docs/concepts/fleets)