@@ -73,7 +73,10 @@ def offload_to_cpu(self) -> None:
7373 if optimizer is not None and hasattr (optimizer , "state" ):
7474 for param_id , state in optimizer .state .items ():
7575 for key , value in state .items ():
76- if not isinstance (value , torch .Tensor ) or value .device .type != "cuda" :
76+ if (
77+ not isinstance (value , torch .Tensor )
78+ or value .device .type != "cuda"
79+ ):
7780 continue
7881 buffer_key = f"opt_{ id (param_id )} _{ key } "
7982 if (
@@ -108,9 +111,14 @@ def reload_to_gpu(self, device: str = "cuda:0") -> None:
108111 if optimizer is not None and hasattr (optimizer , "state" ):
109112 for state in optimizer .state .values ():
110113 for key , value in state .items ():
111- if not isinstance (value , torch .Tensor ) or value .device .type != "cpu" :
114+ if (
115+ not isinstance (value , torch .Tensor )
116+ or value .device .type != "cpu"
117+ ):
112118 continue
113- gpu_tensor = torch .empty (value .shape , dtype = value .dtype , device = device )
119+ gpu_tensor = torch .empty (
120+ value .shape , dtype = value .dtype , device = device
121+ )
114122 gpu_tensor .copy_ (value , non_blocking = True )
115123 state [key ] = gpu_tensor
116124
@@ -224,7 +232,10 @@ def create_unsloth_train_context(
224232 loader_cls .from_pretrained (** init_args ),
225233 )
226234
227- if hasattr (model , "peft_config" ) and getattr (model , "peft_config" , None ) is not None :
235+ if (
236+ hasattr (model , "peft_config" )
237+ and getattr (model , "peft_config" , None ) is not None
238+ ):
228239 peft_model = cast (peft .peft_model .PeftModelForCausalLM , model )
229240 else :
230241 peft_model = cast (
@@ -301,7 +312,9 @@ def _precalculate_new_logprobs(
301312 if isinstance (value , torch .Tensor )
302313 },
303314 pixel_values = packed_tensors ["pixel_values" ][offset : offset + 1 ],
304- image_grid_thw = packed_tensors ["image_grid_thw" ][offset : offset + 1 ],
315+ image_grid_thw = packed_tensors ["image_grid_thw" ][
316+ offset : offset + 1
317+ ],
305318 config = config ,
306319 _config = _config ,
307320 return_new_logprobs = True ,
0 commit comments