
Commit a872571

mukesh reddy (pmukeshreddy) authored and committed
Add missing modified files for SGLang integration
1 parent 486365c commit a872571

2 files changed

Lines changed: 198 additions & 52 deletions

File tree

src/art/unsloth/service.py

Lines changed: 187 additions & 51 deletions
@@ -1,7 +1,7 @@
 """Unsloth training service with decoupled vLLM inference."""

 import asyncio
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 import os
 from typing import TYPE_CHECKING, Any, AsyncIterator, Protocol, cast
@@ -29,6 +29,76 @@
 from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
 from .train import gc_and_empty_cuda_cache, train

+
+# ============================================================================
+# Device Configuration for Multi-GPU Support
+# ============================================================================
+
+
+@dataclass
+class DeviceConfig:
+    """GPU device assignment for Unsloth training and vLLM inference.
+
+    For optimal performance, training and inference should run on separate GPUs.
+    This eliminates memory contention and the need for CPU offloading.
+
+    Attributes:
+        inference_device: GPU index for vLLM inference (default: 0)
+        training_device: GPU index for Unsloth training (default: 1, or 0 if single GPU)
+        auto_detect: If True, automatically detect available GPUs
+
+    Example:
+        # 2-GPU setup (recommended)
+        config = DeviceConfig(inference_device=0, training_device=1)
+
+        # Single GPU (fallback with CPU offloading)
+        config = DeviceConfig(inference_device=0, training_device=0)
+    """
+    inference_device: int = 0
+    training_device: int = 1
+    auto_detect: bool = True
+
+    def __post_init__(self):
+        if self.auto_detect:
+            self._auto_configure()
+
+    def _auto_configure(self):
+        """Auto-detect GPU count and configure devices."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            gpu_count = 1
+
+        if gpu_count == 0:
+            raise RuntimeError("No CUDA GPUs available.")
+        elif gpu_count == 1:
+            # Single GPU: shared mode (will use CPU offloading)
+            self.inference_device = 0
+            self.training_device = 0
+            print(f"[DeviceConfig] Single GPU detected. Using shared mode with CPU offloading.")
+        else:
+            # Multi-GPU: split mode (no offloading needed!)
+            self.inference_device = 0
+            self.training_device = 1
+            print(f"[DeviceConfig] {gpu_count} GPUs detected. Using split mode:")
+            print(f"  - GPU {self.inference_device}: vLLM inference")
+            print(f"  - GPU {self.training_device}: Unsloth training")
+
+    @property
+    def is_split_mode(self) -> bool:
+        """True if inference and training use separate GPUs."""
+        return self.inference_device != self.training_device
+
+    @property
+    def inference_cuda_devices(self) -> str:
+        """CUDA_VISIBLE_DEVICES string for vLLM inference subprocess."""
+        return str(self.inference_device)
+
+    @property
+    def training_cuda_device(self) -> str:
+        """CUDA device string for training (e.g., 'cuda:1')."""
+        return f"cuda:{self.training_device}"
+
 if TYPE_CHECKING:
     from peft.peft_model import PeftModelForCausalLM
     from trl import GRPOTrainer
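
For readers skimming the diff, here is a minimal sketch of how the new DeviceConfig resolves devices. It assumes a fresh process on a 2-GPU machine, and the import path is inferred from the file tree above. Note the design implication: with the default auto_detect=True, _auto_configure overwrites whatever indices the caller passed, so explicit placement requires auto_detect=False.

    from art.unsloth.service import DeviceConfig  # path per this diff

    # On a 2-GPU box, auto-detection picks split mode:
    config = DeviceConfig()
    assert config.is_split_mode
    assert config.inference_cuda_devices == "0"      # CUDA_VISIBLE_DEVICES value for vLLM
    assert config.training_cuda_device == "cuda:1"   # device string for Unsloth

    # Explicit shared-mode placement requires disabling auto-detection:
    shared = DeviceConfig(inference_device=0, training_device=0, auto_detect=False)
    assert not shared.is_split_mode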
@@ -174,79 +244,54 @@ class UnslothState:
     _pinned_buffers: dict[str, torch.Tensor] | None = None

     def offload_to_cpu(self) -> None:
-        """Offload training model and optimizer to CPU using pinned memory for faster transfers."""
+        """Offload entire training model (base + adapters) and optimizer to CPU."""
         if self._is_offloaded:
             return

-        # Initialize pinned buffer storage
-        if self._pinned_buffers is None:
-            self._pinned_buffers = {}
-
-        # Offload model parameters to pinned memory for faster reload
-        for name, param in self.peft_model.named_parameters():
-            if param.device.type == "cuda":
-                # Create pinned buffer if not exists or wrong size
-                if (
-                    name not in self._pinned_buffers
-                    or self._pinned_buffers[name].shape != param.shape
-                ):
-                    self._pinned_buffers[name] = torch.empty(
-                        param.shape, dtype=param.dtype, device="cpu", pin_memory=True
-                    )
-                # Async copy to pinned memory
-                self._pinned_buffers[name].copy_(param.data, non_blocking=True)
-                param.data = self._pinned_buffers[name]
-
-        # Offload optimizer state to pinned memory
+        print("[UnslothService] Offloading entire model to CPU...")
+
+        # Move the entire PEFT model to CPU (this includes base model + adapters)
+        self.peft_model.to("cpu")
+
+        # Offload optimizer state to CPU
         optimizer = getattr(self.trainer, "optimizer", None)
         if optimizer is not None and hasattr(optimizer, "state"):
             for param_id, state in optimizer.state.items():
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and v.device.type == "cuda":
-                        key = f"opt_{id(param_id)}_{k}"
-                        if (
-                            key not in self._pinned_buffers
-                            or self._pinned_buffers[key].shape != v.shape
-                        ):
-                            self._pinned_buffers[key] = torch.empty(
-                                v.shape, dtype=v.dtype, device="cpu", pin_memory=True
-                            )
-                        self._pinned_buffers[key].copy_(v, non_blocking=True)
-                        state[k] = self._pinned_buffers[key]
-
-        # Sync to ensure all copies are complete before freeing GPU memory
-        torch.cuda.synchronize()
+                        state[k] = v.cpu()

+        # Sync and clear GPU memory
+        torch.cuda.synchronize()
         self._is_offloaded = True
         gc_and_empty_cuda_cache()
+
+        # Report free memory
+        free_mem = torch.cuda.mem_get_info()[0] / 1e9
+        print(f"[UnslothService] Model offloaded. GPU memory free: {free_mem:.2f} GB")

     def reload_to_gpu(self, device: str = "cuda:0") -> None:
-        """Reload training model and optimizer back to GPU using async transfers."""
+        """Reload entire training model and optimizer back to GPU."""
         if not self._is_offloaded:
             return

-        # Reload model parameters from pinned memory (fast async transfer)
-        for name, param in self.peft_model.named_parameters():
-            if param.device.type == "cpu":
-                # Allocate on GPU and async copy from pinned memory
-                gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device)
-                gpu_tensor.copy_(param.data, non_blocking=True)
-                param.data = gpu_tensor
+        print(f"[UnslothService] Reloading model to {device}...")
+
+        # Move the entire PEFT model back to GPU
+        self.peft_model.to(device)

-        # Reload optimizer state
+        # Reload optimizer state to GPU
         optimizer = getattr(self.trainer, "optimizer", None)
         if optimizer is not None and hasattr(optimizer, "state"):
             for state in optimizer.state.values():
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and v.device.type == "cpu":
-                        gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device)
-                        gpu_tensor.copy_(v, non_blocking=True)
-                        state[k] = gpu_tensor
+                        state[k] = v.to(device)

         # Sync to ensure all copies are complete before training
         torch.cuda.synchronize()
-
         self._is_offloaded = False
+        print(f"[UnslothService] Model reloaded to {device}")


 # ============================================================================
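
This rewrite trades speed for simplicity: the removed code staged every tensor in reusable pinned CPU buffers so copies could overlap via non_blocking=True, while the new code issues one blocking module-level move. A self-contained sketch of the two patterns, assuming a CUDA device is available (the toy model is ours, for illustration only):

    import torch
    import torch.nn as nn

    model = nn.Linear(4096, 4096).cuda()

    # Removed approach: stage parameters in pinned (page-locked) CPU buffers,
    # which allows asynchronous DMA transfers with non_blocking=True.
    pinned: dict[str, torch.Tensor] = {}
    for name, p in model.named_parameters():
        if name not in pinned or pinned[name].shape != p.shape:
            pinned[name] = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
        pinned[name].copy_(p.data, non_blocking=True)
        p.data = pinned[name]
    torch.cuda.synchronize()  # copies must finish before the GPU memory is reused

    # New approach: one call moves every parameter and buffer, synchronously,
    # through ordinary pageable memory. Slower, but far less bookkeeping.
    model.to("cuda:0")  # bring it back for the demo
    model.to("cpu")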
@@ -260,6 +305,7 @@ class UnslothService:
     base_model: str
     config: dev.InternalModelConfig
     output_dir: str
+    device_config: DeviceConfig = field(default_factory=DeviceConfig)
     _is_sleeping: bool = False
     _latest_step: int = 0
     _lora_id_counter: int = 1  # Start from 1 since 0 is reserved
@@ -283,8 +329,13 @@ async def start_openai_server(
         # Extract step from checkpoint path
         self._latest_step = get_step_from_dir(self.output_dir)

-        # Offload training model to CPU before vLLM starts to free GPU memory
+        # Offload training model to CPU so vLLM can use the GPU
         self._state.offload_to_cpu()
+        # Force garbage collection and clear CUDA cache
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()

         server_config = dev.get_openai_server_config(
             model_name=self.model_name,
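
The gc.collect() / empty_cache() / synchronize() sequence added here is what actually returns the offloaded model's memory to the driver before vLLM allocates. A small helper, hypothetical and not part of the commit, showing how the effect can be measured with the same torch.cuda.mem_get_info() call the offload path prints:

    import gc
    import torch

    def report_free_gpu_memory(tag: str) -> float:
        """Release cached blocks, then print and return free GPU memory in GB."""
        gc.collect()                 # drop Python references to dead tensors
        torch.cuda.empty_cache()     # return cached allocator blocks to the driver
        torch.cuda.synchronize()     # ensure pending kernels and copies have finished
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        print(f"[{tag}] {free_bytes / 1e9:.2f} / {total_bytes / 1e9:.2f} GB free")
        return free_bytes / 1e9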
@@ -334,7 +385,7 @@ async def train(
     ) -> AsyncIterator[dict[str, float]]:
         llm = await self.llm

-        # Pause generation to prevent new requests during training
+        # Time-sharing mode: pause vLLM, free GPU memory, then train
         await llm.pause_generation()

         # Determine sleep level based on outstanding requests:
@@ -364,10 +415,14 @@ async def train(

         # If we haven't already, start the training task
         if not hasattr(self, "_train_task") or self._train_task is None:
+            # Use remapped device index: in split mode with CUDA_VISIBLE_DEVICES=0,1,
+            # training is cuda:1 (second visible device)
+            # Training device is cuda:0
             self._train_task = asyncio.create_task(
                 train(
                     trainer=self._state.trainer,
                     results_queue=self._state.results_queue,
+                    training_device=0,
                 )
             )
         warmup = True
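
The comments in this hunk mention both cuda:1 and cuda:0; the likely reason they can coexist is that CUDA device indices are relative to the visible-device set, not to physical GPUs. Once CUDA_VISIBLE_DEVICES narrows the set, index 0 refers to the first visible device. A sketch of the remapping (the variable must be set before CUDA initializes, i.e. at the start of a fresh process):

    import os

    # Expose only physical GPU 1 to this process...
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    import torch  # CUDA initializes lazily, after the env var is set

    # ...so inside the process the remaining GPU is addressed as cuda:0.
    torch.cuda.set_device(0)
    print(torch.cuda.device_count())  # prints 1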
@@ -396,7 +451,7 @@ async def train(
                 verbose=verbose,
             )

-            # Offload training model to CPU before waking vLLM
+            # Offload training model before waking vLLM
             self._state.offload_to_cpu()

             # Free memory before waking up vLLM
@@ -438,18 +493,32 @@ async def train(
     def _state(self) -> UnslothState:
         import unsloth

+        # Use cuda:0 for training - Unsloth's compiled code expects this
+        # Time-sharing with vLLM via sleep/wake handles memory management
+        cuda_device_index = 0
+        torch.cuda.set_device(cuda_device_index)
+        print(f"[UnslothService] Loading training model on cuda:{cuda_device_index}")
+
         # Initialize Unsloth model
         init_args = self.config.get("init_args", {})
         checkpoint_dir = get_last_checkpoint_dir(self.output_dir)
         if checkpoint_dir:
             init_args["model_name"] = checkpoint_dir
         else:
             init_args["model_name"] = self.base_model
+
+        # Set device_map to cuda:0 - Unsloth expects training on cuda:0
+        if "device_map" not in init_args:
+            init_args["device_map"] = {"": 0}

         model, tokenizer = cast(
             tuple[CausalLM, PreTrainedTokenizerBase],
             unsloth.FastLanguageModel.from_pretrained(**init_args),
         )
+
+        # Verify the model is on the correct device
+        model_device = next(model.parameters()).device
+        print(f"[UnslothService] Model loaded on device: {model_device}, current_device={torch.cuda.current_device()}")

         # Initialize PEFT model - skip if already a PeftModel (e.g. loaded from checkpoint)
         if (
@@ -466,6 +535,56 @@ def _state(self) -> UnslothState:
             ),
         )

+        # Reset AcceleratorState singleton and patch device check before creating trainer
+        # This is necessary because AcceleratorState caches the device from first initialization,
+        # which might have been device 0 (from vLLM or imports). We need it to use device 1.
+        try:
+            from accelerate.state import AcceleratorState
+            from accelerate import Accelerator
+            AcceleratorState._reset_state()
+
+            # Monkey-patch Accelerator to skip device check for 4-bit models
+            # The check fails when model is on GPU 1 but Accelerator was initialized earlier
+            # We need to bypass the check BEFORE original_prepare_model runs
+            original_prepare_model = Accelerator.prepare_model
+            def patched_prepare_model(self, model, device_placement=None, evaluation_mode=False):
+                # For quantized models, temporarily remove the quantization flags to bypass the check
+                # Then restore them after prepare_model completes
+                was_8bit = getattr(model, "is_loaded_in_8bit", False)
+                was_4bit = getattr(model, "is_loaded_in_4bit", False)
+                was_device_map = getattr(model, "hf_device_map", None)
+
+                if was_8bit or was_4bit:
+                    print(f"[UnslothService] Temporarily hiding quantization flags to bypass device check")
+                    # Temporarily hide the quantization flags
+                    model.is_loaded_in_8bit = False
+                    model.is_loaded_in_4bit = False
+                    # Try to delete hf_device_map - it may be on inner model (accessible via __getattr__)
+                    # but not directly deletable from the PEFT wrapper
+                    try:
+                        delattr(model, "hf_device_map")
+                    except AttributeError:
+                        pass  # Attribute is on inner model, not directly on PEFT wrapper
+
+                    try:
+                        result = original_prepare_model(self, model, device_placement, evaluation_mode)
+                    finally:
+                        # Restore the flags
+                        if was_8bit:
+                            model.is_loaded_in_8bit = True
+                        if was_4bit:
+                            model.is_loaded_in_4bit = True
+                        if was_device_map is not None:
+                            model.hf_device_map = was_device_map
+                    return result
+                else:
+                    return original_prepare_model(self, model, device_placement, evaluation_mode)
+            Accelerator.prepare_model = patched_prepare_model
+
+            print(f"[UnslothService] Reset AcceleratorState and patched prepare_model, current_device={torch.cuda.current_device()}")
+        except Exception as e:
+            print(f"[UnslothService] Could not reset AcceleratorState: {e}")

         # Initialize trainer with dummy dataset
         data = {"prompt": ""}
         trainer = GRPOTrainer(
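
The flag-hiding dance inside patched_prepare_model is independent of accelerate and can be restated as a small context manager. A sketch of the same pattern (the name quantization_flags_hidden is ours, not part of the commit); the try/finally guarantees the flags are restored even if the wrapped call raises:

    from contextlib import contextmanager

    @contextmanager
    def quantization_flags_hidden(model):
        """Temporarily clear the is_loaded_in_4bit/8bit flags on a model."""
        was_8bit = getattr(model, "is_loaded_in_8bit", False)
        was_4bit = getattr(model, "is_loaded_in_4bit", False)
        model.is_loaded_in_8bit = False
        model.is_loaded_in_4bit = False
        try:
            yield model
        finally:
            # Restore the flags no matter what happened inside the block
            model.is_loaded_in_8bit = was_8bit
            model.is_loaded_in_4bit = was_4bit

    # Usage (hypothetical):
    #     with quantization_flags_hidden(peft_model):
    #         accelerator.prepare_model(peft_model)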
@@ -504,12 +623,29 @@ async def get_inputs() -> TrainInputs:

     @cached_property
     def llm(self) -> asyncio.Task[AsyncLLM]:
+        # Use single GPU (cuda:0) for both vLLM and Unsloth with time-sharing
+        # Unsloth's compiled training loop expects cuda:0, so split-GPU mode is not supported
+        inference_gpu = self.device_config.inference_device
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(inference_gpu)
+        print(f"[UnslothService] Starting vLLM on GPU {inference_gpu} (time-sharing mode with Unsloth)")
+
         # Filter engine args to remove incompatible boolean flags
         engine_args = {
             **self.config.get("engine_args", {}),
             "enable_lora": True,
             "max_loras": self.config.get("engine_args", {}).get("max_loras", 2),
         }
+
+        # In split mode, vLLM has the full GPU to itself, so use high utilization
+        # In shared mode, use lower utilization to leave room for training model
+        if self.device_config.is_split_mode:
+            if "gpu_memory_utilization" not in engine_args:
+                engine_args["gpu_memory_utilization"] = 0.90
+        else:
+            # Shared mode: lower utilization to coexist with training
+            if "gpu_memory_utilization" not in engine_args:
+                engine_args["gpu_memory_utilization"] = 0.80
+
         # Remove boolean flags that vLLM's argparse doesn't accept as =False
         for key in ["enable_log_requests", "disable_log_requests"]:
             engine_args.pop(key, None)
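
The two nested "if not in" checks amount to a per-mode setdefault: user-supplied values always win. Equivalent logic extracted into a pure function so it can be unit-tested (the helper name is ours, for illustration):

    def resolve_gpu_memory_utilization(engine_args: dict, split_mode: bool) -> dict:
        """Apply the mode-dependent default without overriding a user-set value."""
        engine_args.setdefault("gpu_memory_utilization", 0.90 if split_mode else 0.80)
        return engine_args

    assert resolve_gpu_memory_utilization({}, split_mode=True)["gpu_memory_utilization"] == 0.90
    assert resolve_gpu_memory_utilization({}, split_mode=False)["gpu_memory_utilization"] == 0.80
    assert resolve_gpu_memory_utilization({"gpu_memory_utilization": 0.5}, split_mode=True)["gpu_memory_utilization"] == 0.5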

src/art/unsloth/train.py

Lines changed: 11 additions & 1 deletion
@@ -23,7 +23,14 @@
 async def train(
     trainer: "GRPOTrainer",
     results_queue: asyncio.Queue[dict[str, float]],
+    training_device: int | None = None,
 ) -> None:
+    # Set the CUDA device before training - required for 4-bit/8-bit quantized models
+    # because accelerate checks torch.cuda.current_device() matches the model's device
+    if training_device is not None:
+        torch.cuda.set_device(training_device)
+        print(f"[train] Set CUDA device to {training_device}, current_device={torch.cuda.current_device()}")
+
     _compute_loss = trainer.compute_loss
     _log = trainer.log
     trainer.compute_loss = get_compute_loss_fn(trainer)
@@ -37,7 +44,10 @@ async def train(
     if not is_train_dict:
         trainer._metrics = {"train": defaultdict(list)}
     try:
-        trainer.train()
+        # Use context manager to ensure device is set during training
+        with torch.cuda.device(training_device) if training_device is not None else nullcontext():
+            print(f"[train] About to call trainer.train(), current_device={torch.cuda.current_device()}")
+            trainer.train()
     finally:
         trainer.compute_loss = _compute_loss
         trainer.log = _log  # ty:ignore[invalid-assignment]
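
One caveat on this hunk: nullcontext lives in contextlib, and the import is not visible in the changed lines, so it presumably exists elsewhere in the file. The conditional-context pattern itself, restated in isolation (the helper name call_on_device is ours):

    from contextlib import nullcontext

    import torch

    def call_on_device(fn, training_device: int | None = None):
        """Run fn with an optional CUDA device override (illustrative helper)."""
        # torch.cuda.device(i) switches the current device for the block;
        # nullcontext() is a no-op stand-in when no override is requested.
        ctx = torch.cuda.device(training_device) if training_device is not None else nullcontext()
        with ctx:
            return fn()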
