@@ -28,102 +28,161 @@ class ShotRS2VPipeline(ShotPipeline): # type:ignore
2828 def __init__ (self , clip_configs ):
2929 super ().__init__ (clip_configs )
3030
31+ @staticmethod
32+ def _parse_audio_path (audio_path ):
33+ if os .path .isdir (audio_path ):
34+ audio_config_path = os .path .join (audio_path , "config.json" )
35+ assert os .path .exists (audio_config_path ), "config.json not found in audio_path"
36+ with open (audio_config_path , "r" ) as f :
37+ audio_config = json .load (f )
38+ audio_files = [os .path .join (audio_path , obj ["audio" ]) for obj in audio_config ["talk_objects" ]]
39+ mask_files = [os .path .join (audio_path , obj ["mask" ]) for obj in audio_config ["talk_objects" ]]
40+ else :
41+ audio_files = [audio_path ]
42+ mask_files = None
43+ return audio_files , mask_files
44+
45+ @staticmethod
46+ def _load_single_audio (audio_path , target_sr ):
47+ arr , ori_sr = ta .load (audio_path )
48+ arr = arr .mean (0 )
49+ if ori_sr != target_sr :
50+ arr = ta .functional .resample (arr , ori_sr , target_sr )
51+ return arr
52+
53+ @classmethod
54+ def _load_audio_array (cls , audio_files , audio_sr , video_duration ):
55+ if len (audio_files ) == 1 :
56+ audio_array = cls ._load_single_audio (audio_files [0 ], audio_sr ).unsqueeze (0 )
57+ else :
58+ arrays = [cls ._load_single_audio (f , audio_sr ) for f in audio_files ]
59+ max_len = max (a .numel () for a in arrays )
60+ audio_array = torch .zeros (len (arrays ), max_len , dtype = torch .float32 )
61+ for i , arr in enumerate (arrays ):
62+ audio_array [i , : arr .numel ()] = arr
63+
64+ if video_duration is not None and video_duration > 0 :
65+ max_samples = int (video_duration * audio_sr )
66+ if audio_array .shape [1 ] > max_samples :
67+ audio_array = audio_array [:, :max_samples ]
68+
69+ return audio_array
70+
71+ @staticmethod
72+ def _load_mask_latents (rs2v , mask_files ):
73+ if mask_files is None :
74+ return None
75+ latents = [rs2v .process_single_mask (f ) for f in mask_files ]
76+ return torch .cat (latents , dim = 0 )
77+
78+ @staticmethod
79+ def _calc_total_clips (total_samples , audio_per_frame , target_video_length ):
80+ total_frames = int (np .ceil (total_samples / audio_per_frame ))
81+ if total_frames <= target_video_length :
82+ return 1
83+ remaining = total_frames - target_video_length
84+ return 1 + int (np .ceil (remaining / (target_video_length + 3 )))
85+
86+ @staticmethod
87+ def _update_latent_shape (clip_input_info , target_len , vae_stride ):
88+ if hasattr (clip_input_info , "latent_shape" ) and clip_input_info .latent_shape is not None :
89+ s = clip_input_info .latent_shape
90+ new_t = (target_len - 1 ) // vae_stride + 1
91+ clip_input_info .latent_shape = [s [0 ], new_t , s [2 ], s [3 ]]
92+
93+ def _compute_segment_params (self , idx , audio_clip , pad_len , target_video_length , target_fps , audio_per_frame , vae_stride , clip_input_info ):
94+ """Compute per-segment parameters (target_video_length, latent_shape, trimmed audio_clip).
95+
96+ Returns:
97+ (is_first, is_last, segment_actual_video_frames, audio_clip)
98+ """
99+ is_first = idx == 0
100+ is_last = pad_len > 0
101+ segment_actual_video_frames = None
102+
103+ if is_last :
104+ actual_audio_samples = audio_clip .shape [1 ] - pad_len
105+ actual_video_frames = int (np .ceil (actual_audio_samples / audio_per_frame ))
106+ segment_actual_video_frames = actual_video_frames
107+
108+ seg_target_len = calculate_target_video_length_from_duration (actual_video_frames / target_fps , target_fps )
109+ clip_input_info .target_video_length = seg_target_len
110+ self ._update_latent_shape (clip_input_info , seg_target_len , vae_stride )
111+
112+ logger .info (
113+ f"Segment { idx } : Last segment with pad_len={ pad_len } , "
114+ f"actual_video_frames={ actual_video_frames } , "
115+ f"calculated target_video_length={ seg_target_len } , "
116+ f"latent_shape={ clip_input_info .latent_shape } "
117+ )
118+ audio_clip = audio_clip [:, : clip_input_info .target_video_length * audio_per_frame ]
119+ else :
120+ cur_clip_len = target_video_length if is_first else (target_video_length + 3 )
121+ clip_input_info .target_video_length = cur_clip_len
122+ if not is_first :
123+ self ._update_latent_shape (clip_input_info , cur_clip_len , vae_stride )
124+
125+ return is_first , is_last , segment_actual_video_frames , audio_clip
126+
127+ @staticmethod
128+ def _trim_segment (gen_clip_video , audio_clip , segment_actual_video_frames , audio_per_frame ):
129+ if segment_actual_video_frames is not None :
130+ video_seg = gen_clip_video [:, :, :segment_actual_video_frames ]
131+ audio_seg = audio_clip [:, : segment_actual_video_frames * audio_per_frame ].sum (dim = 0 )
132+ else :
133+ video_seg = gen_clip_video
134+ audio_seg = audio_clip .sum (dim = 0 )
135+ return video_seg , audio_seg
136+
137+ @staticmethod
138+ def _merge_and_save (gen_video_list , cut_audio_list , target_fps , save_result_path ):
139+ gen_lvideo = torch .cat (gen_video_list , dim = 2 ).float ()
140+ gen_lvideo = torch .clamp (gen_lvideo , - 1 , 1 )
141+ merge_audio = torch .cat (cut_audio_list , dim = 0 ).numpy ().astype (np .float32 )
142+
143+ if is_main_process () and save_result_path :
144+ out_path = os .path .join ("./" , "video_merge.mp4" )
145+ audio_file = os .path .join ("./" , "audio_merge.wav" )
146+ save_to_video (gen_lvideo , out_path , target_fps )
147+ save_audio (merge_audio , audio_file , out_path , output_path = save_result_path )
148+ os .remove (out_path )
149+ os .remove (audio_file )
150+
151+ return gen_lvideo , merge_audio
152+
31153 @torch .no_grad ()
32154 def generate (self , args ):
33155 rs2v = self .clip_generators ["rs2v_clip" ]
34- # 获取此clip模型的配置信息
156+
35157 target_fps = rs2v .config .get ("target_fps" , 16 )
36158 audio_sr = rs2v .config .get ("audio_sr" , 16000 )
37159 audio_per_frame = audio_sr // target_fps
160+ vae_stride = rs2v .config ["vae_stride" ][0 ]
38161
39- # 获取用户输入信息
40162 clip_input_info = init_input_info_from_args (rs2v .config ["task" ], args )
41- # 从默认配置中补全输入信息
42163 clip_input_info = self .check_input_info (clip_input_info , rs2v .config )
43- target_video_length = clip_input_info .target_video_length
44164
45- # Auto-calculate target_video_length from video_duration if not explicitly provided
46165 if clip_input_info .target_video_length is None or clip_input_info .target_video_length == UNSET :
47166 if clip_input_info .video_duration is not None and clip_input_info .video_duration != UNSET :
48- # Calculate for the first segment (max 5s)
49167 segment_duration = min (clip_input_info .video_duration , 5.0 )
50168 clip_input_info .target_video_length = calculate_target_video_length_from_duration (segment_duration , target_fps )
51169 logger .info (f"Auto-calculated target_video_length={ clip_input_info .target_video_length } from video_duration={ clip_input_info .video_duration } s (segment={ segment_duration } s)" )
52170 else :
53- # Fallback to config default
54171 clip_input_info .target_video_length = rs2v .config .get ("target_video_length" , 81 )
55172
56173 target_video_length = clip_input_info .target_video_length
174+ base_seed = clip_input_info .seed
57175
58- gen_video_list = []
59- cut_audio_list = []
60- video_duration = clip_input_info .video_duration
61-
62- def get_audio_files_from_audio_path (audio_path ):
63- if os .path .isdir (audio_path ):
64- audio_files = []
65- mask_files = []
66- audio_config_path = os .path .join (audio_path , "config.json" )
67- assert os .path .exists (audio_config_path ), "config.json not found in audio_path"
68- with open (audio_config_path , "r" ) as f :
69- audio_config = json .load (f )
70- for talk_object in audio_config ["talk_objects" ]:
71- audio_files .append (os .path .join (audio_path , talk_object ["audio" ]))
72- mask_files .append (os .path .join (audio_path , talk_object ["mask" ]))
73- else :
74- audio_files = [audio_path ]
75- mask_files = None
76- return audio_files , mask_files
77-
78- def load_audio (audio_path , target_sr ):
79- arr , ori_sr = ta .load (audio_path )
80- arr = arr .mean (0 )
81- if ori_sr != target_sr :
82- arr = ta .functional .resample (arr , ori_sr , target_sr )
83- return arr
84-
85- audio_files , mask_files = get_audio_files_from_audio_path (clip_input_info .audio_path )
176+ audio_files , mask_files = self ._parse_audio_path (clip_input_info .audio_path )
86177 clip_input_info .audio_num = len (audio_files )
87178
88- if len (audio_files ) == 1 :
89- audio_array = load_audio (audio_files [0 ], audio_sr )
90- audio_array = audio_array .unsqueeze (0 )
91- else :
92- audio_arrays = []
93- max_len = 0
94- for a_file in audio_files :
95- arr = load_audio (a_file , audio_sr )
96- audio_arrays .append (arr )
97- max_len = max (max_len , arr .numel ())
98- num_files = len (audio_arrays )
99- audio_array = torch .zeros (num_files , max_len , dtype = torch .float32 )
100- for i , arr in enumerate (audio_arrays ):
101- length = arr .numel ()
102- audio_array [i , :length ] = arr
103-
104- if video_duration is not None and video_duration > 0 :
105- max_samples = int (video_duration * audio_sr )
106- if audio_array .shape [1 ] > max_samples :
107- audio_array = audio_array [:, :max_samples ]
108-
109- if mask_files is not None :
110- mask_latents = [rs2v .process_single_mask (mask_file ) for mask_file in mask_files ]
111- person_mask_latens = torch .cat (mask_latents , dim = 0 )
112- else :
113- person_mask_latens = None
179+ audio_array = self ._load_audio_array (audio_files , audio_sr , clip_input_info .video_duration )
180+ person_mask_latens = self ._load_mask_latents (rs2v , mask_files )
114181
115182 audio_reader = RS2V_SlidingWindowReader (audio_array , first_clip_len = target_video_length , clip_len = target_video_length + 3 , sr = audio_sr , fps = target_fps )
183+ total_clips = self ._calc_total_clips (audio_array .shape [1 ], audio_per_frame , target_video_length )
184+ ref_state_seq = get_reference_state_sequence (target_video_length - 3 , target_fps )
116185
117- total_frames = int (np .ceil (audio_array .shape [1 ] / audio_per_frame ))
118- if total_frames <= target_video_length :
119- total_clips = 1
120- else :
121- remaining = total_frames - target_video_length
122- total_clips = 1 + int (np .ceil (remaining / (target_video_length + 3 )))
123-
124- ref_state_sq = get_reference_state_sequence (target_video_length - 3 , target_fps )
125-
126- # 预先运行输入编码的静态部分 (处理ref image的vae编码和文本编码)
127186 rs2v .input_info = clip_input_info
128187 rs2v .inputs_static = rs2v ._run_input_encoder_local_rs2v_static ()
129188
@@ -132,90 +191,33 @@ def load_audio(audio_path, target_sr):
132191 self .va_controller = VAController (rs2v )
133192 logger .info (f"init va_recorder: { self .va_controller .recorder } and va_reader: { self .va_controller .reader } " )
134193
135- idx = 0
136- while True :
137- audio_clip , pad_len = audio_reader .next_frame ()
194+ gen_video_list = []
195+ cut_audio_list = []
196+
197+ for idx , (audio_clip , pad_len ) in enumerate (iter (audio_reader .next_frame , (None , 0 ))):
138198 if audio_clip is None :
139199 break
200+ rs2v .check_stop ()
140201
141- is_first = True if idx == 0 else False
142- is_last = True if pad_len > 0 else False
143-
144- pipe = rs2v
145- pipe .check_stop ()
146-
147- # Calculate actual target_video_length for this segment based on audio length
148- segment_actual_video_frames = None # Track for last-segment trimming
149- if is_last and pad_len > 0 :
150- # For the last segment with padding, calculate actual video frames needed
151- actual_audio_samples = audio_clip .shape [1 ] - pad_len
152- actual_video_frames = int (np .ceil (actual_audio_samples / audio_per_frame ))
153- segment_actual_video_frames = actual_video_frames
154- # Apply the formula to ensure VAE stride constraint
155- segment_target_video_length = calculate_target_video_length_from_duration (actual_video_frames / target_fps , target_fps )
156- clip_input_info .target_video_length = segment_target_video_length
157-
158- # Recalculate latent_shape for this segment
159- # latent_shape = [C, T, H, W] where T = (target_video_length - 1) // vae_stride[0] + 1
160- if hasattr (clip_input_info , "latent_shape" ) and clip_input_info .latent_shape is not None :
161- original_latent_shape = clip_input_info .latent_shape
162- # Update only the temporal dimension (index 1)
163- new_latent_t = (segment_target_video_length - 1 ) // rs2v .config ["vae_stride" ][0 ] + 1
164- clip_input_info .latent_shape = [
165- original_latent_shape [0 ], # C
166- new_latent_t , # T
167- original_latent_shape [2 ], # H
168- original_latent_shape [3 ], # W
169- ]
170-
171- logger .info (
172- f"Segment { idx } : Last segment with pad_len={ pad_len } , "
173- f"actual_video_frames={ actual_video_frames } , "
174- f"calculated target_video_length={ segment_target_video_length } , "
175- f"latent_shape={ clip_input_info .latent_shape } "
176- )
177- else :
178- # For first/middle segments, use the original target_video_length
179- cur_clip_len = target_video_length if is_first else (target_video_length + 3 )
180- clip_input_info .target_video_length = cur_clip_len
181-
182- # Recalculate latent_shape for non-first segments
183- if not is_first and hasattr (clip_input_info , "latent_shape" ) and clip_input_info .latent_shape is not None :
184- original_latent_shape = clip_input_info .latent_shape
185- new_latent_t = (cur_clip_len - 1 ) // rs2v .config ["vae_stride" ][0 ] + 1
186- clip_input_info .latent_shape = [
187- original_latent_shape [0 ],
188- new_latent_t ,
189- original_latent_shape [2 ],
190- original_latent_shape [3 ],
191- ]
202+ is_first , is_last , segment_actual_frames , audio_clip = self ._compute_segment_params (idx , audio_clip , pad_len , target_video_length , target_fps , audio_per_frame , vae_stride , clip_input_info )
192203
193204 clip_input_info .is_first = is_first
194205 clip_input_info .is_last = is_last
195- clip_input_info .ref_state = ref_state_sq [idx % len (ref_state_sq )]
196- clip_input_info .seed = clip_input_info . seed + idx
206+ clip_input_info .ref_state = ref_state_seq [idx % len (ref_state_seq )]
207+ clip_input_info .seed = base_seed + idx
197208 clip_input_info .audio_clip = audio_clip
198- idx = idx + 1
209+ clip_input_info .person_mask_latens = person_mask_latens
210+
199211 if self .progress_callback :
200- self .progress_callback (idx , total_clips )
212+ self .progress_callback (idx + 1 , total_clips )
201213
202214 rs2v .input_info = clip_input_info
203- clip_input_info .person_mask_latens = person_mask_latens
204-
205- # 使用动态输入获取当前 clip 控制参数
206215 rs2v .inputs = rs2v ._run_input_encoder_local_rs2v_dynamic ()
207-
208216 gen_clip_video , audio_clip , gen_latents = rs2v .run_clip_main ()
209- logger .info (f"Generated rs2v clip { idx } , pad_len { pad_len } , gen_clip_video shape: { gen_clip_video .shape } , audio_clip shape: { audio_clip .shape } gen_latents shape: { gen_latents .shape } " )
210217
211- if segment_actual_video_frames is not None :
212- # Last segment: trim to exact actual frames needed
213- video_seg = gen_clip_video [:, :, :segment_actual_video_frames ]
214- audio_seg = audio_clip [:, : segment_actual_video_frames * audio_per_frame ].sum (dim = 0 )
215- else :
216- video_seg = gen_clip_video
217- audio_seg = audio_clip .sum (dim = 0 )
218+ logger .info (f"Generated rs2v clip { idx + 1 } , pad_len={ pad_len } , gen_clip_video={ gen_clip_video .shape } , audio_clip={ audio_clip .shape } , gen_latents={ gen_latents .shape } " )
218219
220+ video_seg , audio_seg = self ._trim_segment (gen_clip_video , audio_clip , segment_actual_frames , audio_per_frame )
219221 clip_input_info .overlap_latent = gen_latents [:, - 1 :]
220222
221223 if clip_input_info .return_result_tensor or not clip_input_info .stream_save_video :
@@ -229,23 +231,10 @@ def load_audio(audio_path, target_sr):
229231 if not clip_input_info .return_result_tensor and clip_input_info .stream_save_video :
230232 return None , None , None
231233
232- gen_lvideo = torch .cat (gen_video_list , dim = 2 ).float ()
233- gen_lvideo = torch .clamp (gen_lvideo , - 1 , 1 )
234- merge_audio = torch .cat (cut_audio_list , dim = 0 ).numpy ().astype (np .float32 )
235-
236- if is_main_process () and clip_input_info .save_result_path :
237- out_path = os .path .join ("./" , "video_merge.mp4" )
238- audio_file = os .path .join ("./" , "audio_merge.wav" )
239-
240- save_to_video (gen_lvideo , out_path , 16 )
241- save_audio (merge_audio , audio_file , out_path , output_path = clip_input_info .save_result_path )
242- os .remove (out_path )
243- os .remove (audio_file )
244-
234+ gen_lvideo , merge_audio = self ._merge_and_save (gen_video_list , cut_audio_list , target_fps , clip_input_info .save_result_path )
245235 return gen_lvideo , merge_audio , audio_sr
246236
247237 def run_pipeline (self , input_info ):
248- # input_info = self.update_input_info(input_info)
249238 try :
250239 gen_lvideo , merge_audio , audio_sr = self .generate (input_info )
251240 finally :
0 commit comments