11"""Unsloth training service with decoupled vLLM inference."""
22
33import asyncio
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55from functools import cached_property
66import os
77from typing import TYPE_CHECKING , Any , AsyncIterator , Protocol , cast
2929from ..vllm import get_llm , get_worker , openai_server_task , run_on_workers
3030from .train import gc_and_empty_cuda_cache , train
3131
+
+# ============================================================================
+# Device Configuration for Multi-GPU Support
+# ============================================================================
+
+
+@dataclass
+class DeviceConfig:
+    """GPU device assignment for Unsloth training and vLLM inference.
+
+    For optimal performance, training and inference should run on separate GPUs.
+    This eliminates memory contention and the need for CPU offloading.
+
+    Attributes:
+        inference_device: GPU index for vLLM inference (default: 0)
+        training_device: GPU index for Unsloth training (default: 1, or 0 if single GPU)
+        auto_detect: If True, automatically detect available GPUs
+
+    Example:
+        # 2-GPU setup (recommended)
+        config = DeviceConfig(inference_device=0, training_device=1)
+
+        # Single GPU (fallback with CPU offloading)
+        config = DeviceConfig(inference_device=0, training_device=0)
+    """
+    inference_device: int = 0
+    training_device: int = 1
+    auto_detect: bool = True
+
+    def __post_init__(self):
+        if self.auto_detect:
+            self._auto_configure()
+
+    def _auto_configure(self):
+        """Auto-detect GPU count and configure devices."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            gpu_count = 1
+
+        if gpu_count == 0:
+            raise RuntimeError("No CUDA GPUs available.")
+        elif gpu_count == 1:
+            # Single GPU: shared mode (will use CPU offloading)
+            self.inference_device = 0
+            self.training_device = 0
+            print("[DeviceConfig] Single GPU detected. Using shared mode with CPU offloading.")
+        else:
+            # Multi-GPU: split mode (no offloading needed!)
+            self.inference_device = 0
+            self.training_device = 1
+            print(f"[DeviceConfig] {gpu_count} GPUs detected. Using split mode:")
+            print(f"  - GPU {self.inference_device}: vLLM inference")
+            print(f"  - GPU {self.training_device}: Unsloth training")
+
+    @property
+    def is_split_mode(self) -> bool:
+        """True if inference and training use separate GPUs."""
+        return self.inference_device != self.training_device
+
+    @property
+    def inference_cuda_devices(self) -> str:
+        """CUDA_VISIBLE_DEVICES string for the vLLM inference subprocess."""
+        return str(self.inference_device)
+
+    @property
+    def training_cuda_device(self) -> str:
+        """CUDA device string for training (e.g., 'cuda:1')."""
+        return f"cuda:{self.training_device}"
+
+
 if TYPE_CHECKING:
     from peft.peft_model import PeftModelForCausalLM
     from trl import GRPOTrainer
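
For reference, a minimal usage sketch of the DeviceConfig added above (illustrative only, not part of the patch; `auto_detect=False` pins devices explicitly instead of probing `torch.cuda.device_count()`):

    import os

    # Hypothetical caller-side wiring for a 2-GPU host.
    config = DeviceConfig(inference_device=0, training_device=1, auto_detect=False)
    if config.is_split_mode:
        # Restrict the vLLM subprocess to its own GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = config.inference_cuda_devices
    print(config.training_cuda_device)  # -> "cuda:1"
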
@@ -174,79 +244,54 @@ class UnslothState:
     _pinned_buffers: dict[str, torch.Tensor] | None = None
 
     def offload_to_cpu(self) -> None:
-        """Offload training model and optimizer to CPU using pinned memory for faster transfers."""
+        """Offload the entire training model (base + adapters) and optimizer to CPU."""
         if self._is_offloaded:
             return
 
-        # Initialize pinned buffer storage
-        if self._pinned_buffers is None:
-            self._pinned_buffers = {}
-
-        # Offload model parameters to pinned memory for faster reload
-        for name, param in self.peft_model.named_parameters():
-            if param.device.type == "cuda":
-                # Create a pinned buffer if one doesn't exist or has the wrong size
-                if (
-                    name not in self._pinned_buffers
-                    or self._pinned_buffers[name].shape != param.shape
-                ):
-                    self._pinned_buffers[name] = torch.empty(
-                        param.shape, dtype=param.dtype, device="cpu", pin_memory=True
-                    )
-                # Async copy to pinned memory
-                self._pinned_buffers[name].copy_(param.data, non_blocking=True)
-                param.data = self._pinned_buffers[name]
-
-        # Offload optimizer state to pinned memory
+        print("[UnslothService] Offloading entire model to CPU...")
+
+        # Move the entire PEFT model to CPU (this includes base model + adapters)
+        self.peft_model.to("cpu")
+
+        # Offload optimizer state to CPU
         optimizer = getattr(self.trainer, "optimizer", None)
         if optimizer is not None and hasattr(optimizer, "state"):
             for param_id, state in optimizer.state.items():
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and v.device.type == "cuda":
-                        key = f"opt_{id(param_id)}_{k}"
-                        if (
-                            key not in self._pinned_buffers
-                            or self._pinned_buffers[key].shape != v.shape
-                        ):
-                            self._pinned_buffers[key] = torch.empty(
-                                v.shape, dtype=v.dtype, device="cpu", pin_memory=True
-                            )
-                        self._pinned_buffers[key].copy_(v, non_blocking=True)
-                        state[k] = self._pinned_buffers[key]
-
-        # Sync to ensure all copies are complete before freeing GPU memory
-        torch.cuda.synchronize()
+                        state[k] = v.cpu()
 
+        # Sync and clear GPU memory
+        torch.cuda.synchronize()
         self._is_offloaded = True
         gc_and_empty_cuda_cache()
+
+        # Report free memory
+        free_mem = torch.cuda.mem_get_info()[0] / 1e9
+        print(f"[UnslothService] Model offloaded. GPU memory free: {free_mem:.2f} GB")
 
     def reload_to_gpu(self, device: str = "cuda:0") -> None:
-        """Reload training model and optimizer back to GPU using async transfers."""
+        """Reload the entire training model and optimizer back to GPU."""
         if not self._is_offloaded:
             return
 
-        # Reload model parameters from pinned memory (fast async transfer)
-        for name, param in self.peft_model.named_parameters():
-            if param.device.type == "cpu":
-                # Allocate on GPU and async copy from pinned memory
-                gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device)
-                gpu_tensor.copy_(param.data, non_blocking=True)
-                param.data = gpu_tensor
+        print(f"[UnslothService] Reloading model to {device}...")
+
+        # Move the entire PEFT model back to GPU
+        self.peft_model.to(device)
 
-        # Reload optimizer state
+        # Reload optimizer state to GPU
         optimizer = getattr(self.trainer, "optimizer", None)
         if optimizer is not None and hasattr(optimizer, "state"):
             for state in optimizer.state.values():
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and v.device.type == "cpu":
-                        gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device)
-                        gpu_tensor.copy_(v, non_blocking=True)
-                        state[k] = gpu_tensor
+                        state[k] = v.to(device)
 
         # Sync to ensure all copies are complete before training
         torch.cuda.synchronize()
-
         self._is_offloaded = False
+        print(f"[UnslothService] Model reloaded to {device}")
 
 
 # ============================================================================
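
A quick way to sanity-check the simplified offload path above (a sketch, not part of the patch; assumes `state` is an UnslothState instance on a CUDA host; `torch.cuda.mem_get_info()` returns `(free_bytes, total_bytes)` for the current device):

    import torch

    free_before, total = torch.cuda.mem_get_info()
    state.offload_to_cpu()
    free_after, _ = torch.cuda.mem_get_info()
    print(f"reclaimed {(free_after - free_before) / 1e9:.2f} GB of {total / 1e9:.2f} GB")
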
@@ -260,6 +305,7 @@ class UnslothService:
     base_model: str
     config: dev.InternalModelConfig
     output_dir: str
+    device_config: DeviceConfig = field(default_factory=DeviceConfig)
     _is_sleeping: bool = False
     _latest_step: int = 0
     _lora_id_counter: int = 1  # Start from 1 since 0 is reserved
@@ -283,8 +329,13 @@ async def start_openai_server(
         # Extract step from checkpoint path
         self._latest_step = get_step_from_dir(self.output_dir)
 
-        # Offload training model to CPU before vLLM starts to free GPU memory
+        # Offload the training model to CPU so vLLM can use the GPU
         self._state.offload_to_cpu()
+        # Force garbage collection and clear the CUDA cache
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
         server_config = dev.get_openai_server_config(
             model_name=self.model_name,
@@ -334,7 +385,7 @@ async def train(
     ) -> AsyncIterator[dict[str, float]]:
         llm = await self.llm
 
-        # Pause generation to prevent new requests during training
+        # Time-sharing mode: pause vLLM, free GPU memory, then train
         await llm.pause_generation()
 
         # Determine sleep level based on outstanding requests:
@@ -364,10 +415,14 @@ async def train(
 
         # If we haven't already, start the training task
         if not hasattr(self, "_train_task") or self._train_task is None:
+            # Training always uses device index 0: Unsloth's compiled path expects
+            # cuda:0, and CUDA_VISIBLE_DEVICES controls which physical GPU that is.
             self._train_task = asyncio.create_task(
                 train(
                     trainer=self._state.trainer,
                     results_queue=self._state.results_queue,
+                    training_device=0,
                 )
             )
             warmup = True
@@ -396,7 +451,7 @@ async def train(
             verbose=verbose,
         )
 
-        # Offload training model to CPU before waking vLLM
+        # Offload the training model before waking vLLM
         self._state.offload_to_cpu()
 
         # Free memory before waking up vLLM
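
Condensed, the time-sharing flow above looks roughly like the sketch below (`pause_generation` appears in this file; `sleep`/`wake_up` follow vLLM's sleep-mode API; `run_grpo_step` is a hypothetical stand-in for the training call):

    async def time_shared_step(llm, state):
        await llm.pause_generation()  # stop accepting new inference requests
        await llm.sleep(level=1)      # release vLLM's GPU memory
        state.reload_to_gpu()         # bring the training model back to the GPU
        run_grpo_step()               # hypothetical: one training step
        state.offload_to_cpu()        # hand the GPU back to vLLM
        await llm.wake_up()           # restore vLLM weights / KV cache
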
@@ -438,18 +493,32 @@ async def train(
     def _state(self) -> UnslothState:
         import unsloth
 
+        # Use cuda:0 for training - Unsloth's compiled code expects this.
+        # Time-sharing with vLLM via sleep/wake handles memory management.
+        cuda_device_index = 0
+        torch.cuda.set_device(cuda_device_index)
+        print(f"[UnslothService] Loading training model on cuda:{cuda_device_index}")
+
         # Initialize Unsloth model
         init_args = self.config.get("init_args", {})
         checkpoint_dir = get_last_checkpoint_dir(self.output_dir)
         if checkpoint_dir:
             init_args["model_name"] = checkpoint_dir
         else:
             init_args["model_name"] = self.base_model
+
+        # Set device_map to cuda:0 - Unsloth expects training on cuda:0
+        if "device_map" not in init_args:
+            init_args["device_map"] = {"": 0}
 
         model, tokenizer = cast(
             tuple[CausalLM, PreTrainedTokenizerBase],
             unsloth.FastLanguageModel.from_pretrained(**init_args),
         )
+
+        # Verify the model landed on the expected device
+        model_device = next(model.parameters()).device
+        print(f"[UnslothService] Model loaded on device: {model_device}, current_device={torch.cuda.current_device()}")
 
         # Initialize PEFT model - skip if already a PeftModel (e.g. loaded from checkpoint)
         if (
@@ -466,6 +535,56 @@ def _state(self) -> UnslothState:
             ),
         )
 
+        # Reset the AcceleratorState singleton and patch its device check before
+        # creating the trainer. AcceleratorState caches the device from its first
+        # initialization (possibly from vLLM or an earlier import), which may not
+        # match the device the training model now lives on.
+        try:
+            from accelerate import Accelerator
+            from accelerate.state import AcceleratorState
+
+            AcceleratorState._reset_state()
+
+            # Monkey-patch Accelerator to skip the device check for quantized models.
+            # The check fails when the model's device differs from the one Accelerator
+            # cached earlier, so it must be bypassed BEFORE original_prepare_model runs.
+            original_prepare_model = Accelerator.prepare_model
+
+            def patched_prepare_model(self, model, device_placement=None, evaluation_mode=False):
+                # For quantized models, temporarily remove the quantization flags to
+                # bypass the check, then restore them after prepare_model completes.
+                was_8bit = getattr(model, "is_loaded_in_8bit", False)
+                was_4bit = getattr(model, "is_loaded_in_4bit", False)
+                was_device_map = getattr(model, "hf_device_map", None)
+
+                if was_8bit or was_4bit:
+                    print("[UnslothService] Temporarily hiding quantization flags to bypass device check")
+                    # Temporarily hide the quantization flags
+                    model.is_loaded_in_8bit = False
+                    model.is_loaded_in_4bit = False
+                    # Try to delete hf_device_map - it may live on the inner model
+                    # (reachable via __getattr__) without being deletable from the PEFT wrapper
+                    try:
+                        delattr(model, "hf_device_map")
+                    except AttributeError:
+                        pass  # Attribute is on the inner model, not directly on the PEFT wrapper
+
+                    try:
+                        result = original_prepare_model(self, model, device_placement, evaluation_mode)
+                    finally:
+                        # Restore the flags
+                        if was_8bit:
+                            model.is_loaded_in_8bit = True
+                        if was_4bit:
+                            model.is_loaded_in_4bit = True
+                        if was_device_map is not None:
+                            model.hf_device_map = was_device_map
+                    return result
+                else:
+                    return original_prepare_model(self, model, device_placement, evaluation_mode)
+
+            Accelerator.prepare_model = patched_prepare_model
+
+            print(f"[UnslothService] Reset AcceleratorState and patched prepare_model, current_device={torch.cuda.current_device()}")
+        except Exception as e:
+            print(f"[UnslothService] Could not reset AcceleratorState: {e}")
+
         # Initialize trainer with dummy dataset
         data = {"prompt": ""}
         trainer = GRPOTrainer(
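
The flag-hiding in patched_prepare_model above is an instance of a general "mask attributes, call, restore" pattern; a compact sketch of the same idea as a context manager (illustrative only, not part of the patch):

    from contextlib import contextmanager

    @contextmanager
    def hidden_quant_flags(model):
        # Save and mask the attributes Accelerator's device check reads,
        # then restore them no matter how the wrapped call exits.
        saved = (
            getattr(model, "is_loaded_in_8bit", False),
            getattr(model, "is_loaded_in_4bit", False),
        )
        model.is_loaded_in_8bit = model.is_loaded_in_4bit = False
        try:
            yield model
        finally:
            model.is_loaded_in_8bit, model.is_loaded_in_4bit = saved
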
@@ -504,12 +623,29 @@ async def get_inputs() -> TrainInputs:
 
     @cached_property
     def llm(self) -> asyncio.Task[AsyncLLM]:
+        # Use a single GPU (cuda:0) for both vLLM and Unsloth with time-sharing.
+        # Unsloth's compiled training loop expects cuda:0, so split-GPU mode is not supported.
+        inference_gpu = self.device_config.inference_device
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(inference_gpu)
+        print(f"[UnslothService] Starting vLLM on GPU {inference_gpu} (time-sharing mode with Unsloth)")
+
         # Filter engine args to remove incompatible boolean flags
         engine_args = {
             **self.config.get("engine_args", {}),
             "enable_lora": True,
             "max_loras": self.config.get("engine_args", {}).get("max_loras", 2),
         }
+
+        # If device_config still reports split mode, vLLM has the full GPU to itself,
+        # so default to high utilization; in shared mode, leave headroom for training.
+        if self.device_config.is_split_mode:
+            if "gpu_memory_utilization" not in engine_args:
+                engine_args["gpu_memory_utilization"] = 0.90
+        else:
+            # Shared mode: lower utilization to coexist with the training model
+            if "gpu_memory_utilization" not in engine_args:
+                engine_args["gpu_memory_utilization"] = 0.80
+
         # Remove boolean flags that vLLM's argparse doesn't accept as =False
         for key in ["enable_log_requests", "disable_log_requests"]:
             engine_args.pop(key, None)
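
One caveat when reading the device comments in this patch: once CUDA_VISIBLE_DEVICES is set, PyTorch renumbers the remaining GPUs from zero, so cuda:0 in-process may be a different physical device. A sketch (assumes a 2-GPU host; the variable must be set before CUDA is first initialized):

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    import torch

    # Only physical GPU 1 is visible now, and it is addressed as cuda:0.
    assert torch.cuda.device_count() == 1
    t = torch.zeros(1, device="cuda:0")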