Skip to content

Commit edb0c2c

Browse files
committed
feat: Enhance last segment handling in ShotRS2VPipeline with precise video and audio trimming
1 parent f766d71 commit edb0c2c

1 file changed

Lines changed: 12 additions & 6 deletions

File tree

lightx2v/shot_runner/rs2v_infer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,12 @@ def load_audio(audio_path, target_sr):
             pipe.check_stop()

             # Calculate actual target_video_length for this segment based on audio length
+            segment_actual_video_frames = None  # Track for last-segment trimming
             if is_last and pad_len > 0:
                 # For the last segment with padding, calculate actual video frames needed
                 actual_audio_samples = audio_clip.shape[1] - pad_len
                 actual_video_frames = int(np.ceil(actual_audio_samples / audio_per_frame))
+                segment_actual_video_frames = actual_video_frames
                 # Apply the formula to ensure VAE stride constraint
                 segment_target_video_length = calculate_target_video_length_from_duration(actual_video_frames / target_fps, target_fps)
                 clip_input_info.target_video_length = segment_target_video_length
@@ -206,15 +208,19 @@ def load_audio(audio_path, target_sr):
             gen_clip_video, audio_clip, gen_latents = rs2v.run_clip_main()
             logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}")

-            video_pad_len = pad_len // audio_per_frame
-            audio_pad_len = video_pad_len * audio_per_frame
-            video_seg = gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len]
-            # Since audio_clip is now multidimensional (N, T), slice on dim 1 and sum on dim 0 to merge tracks
-            audio_seg = audio_clip[:, : audio_clip.shape[1] - audio_pad_len].sum(dim=0)
+            if segment_actual_video_frames is not None:
+                # Last segment: trim to exact actual frames needed
+                video_seg = gen_clip_video[:, :, :segment_actual_video_frames]
+                audio_seg = audio_clip[:, : segment_actual_video_frames * audio_per_frame].sum(dim=0)
+            else:
+                video_pad_len = pad_len // audio_per_frame
+                audio_pad_len = video_pad_len * audio_per_frame
+                video_seg = gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len]
+                audio_seg = audio_clip[:, : audio_clip.shape[1] - audio_pad_len].sum(dim=0)
             clip_input_info.overlap_latent = gen_latents[:, -1:]

             if clip_input_info.return_result_tensor or not clip_input_info.stream_save_video:
-                gen_video_list.append(video_seg.clone().cpu().float())
+                gen_video_list.append(video_seg.clone().cpu())
                 cut_audio_list.append(audio_seg.cpu())
             elif self.va_controller.recorder is not None:
                 video_seg = torch.clamp(video_seg, -1, 1).to(torch.float).cpu()

0 commit comments

Comments
 (0)