Skip to content

Commit ed998e0

Browse files
committed
fix(rs2v_infer): modularize and optimize audio and video segment processing
1 parent f4c5184 commit ed998e0

1 file changed

Lines changed: 144 additions & 155 deletions

File tree

lightx2v/shot_runner/rs2v_infer.py

Lines changed: 144 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -28,102 +28,161 @@ class ShotRS2VPipeline(ShotPipeline): # type:ignore
2828
def __init__(self, clip_configs):
2929
super().__init__(clip_configs)
3030

31+
@staticmethod
32+
def _parse_audio_path(audio_path):
33+
if os.path.isdir(audio_path):
34+
audio_config_path = os.path.join(audio_path, "config.json")
35+
assert os.path.exists(audio_config_path), "config.json not found in audio_path"
36+
with open(audio_config_path, "r") as f:
37+
audio_config = json.load(f)
38+
audio_files = [os.path.join(audio_path, obj["audio"]) for obj in audio_config["talk_objects"]]
39+
mask_files = [os.path.join(audio_path, obj["mask"]) for obj in audio_config["talk_objects"]]
40+
else:
41+
audio_files = [audio_path]
42+
mask_files = None
43+
return audio_files, mask_files
44+
45+
@staticmethod
46+
def _load_single_audio(audio_path, target_sr):
47+
arr, ori_sr = ta.load(audio_path)
48+
arr = arr.mean(0)
49+
if ori_sr != target_sr:
50+
arr = ta.functional.resample(arr, ori_sr, target_sr)
51+
return arr
52+
53+
@classmethod
54+
def _load_audio_array(cls, audio_files, audio_sr, video_duration):
55+
if len(audio_files) == 1:
56+
audio_array = cls._load_single_audio(audio_files[0], audio_sr).unsqueeze(0)
57+
else:
58+
arrays = [cls._load_single_audio(f, audio_sr) for f in audio_files]
59+
max_len = max(a.numel() for a in arrays)
60+
audio_array = torch.zeros(len(arrays), max_len, dtype=torch.float32)
61+
for i, arr in enumerate(arrays):
62+
audio_array[i, : arr.numel()] = arr
63+
64+
if video_duration is not None and video_duration > 0:
65+
max_samples = int(video_duration * audio_sr)
66+
if audio_array.shape[1] > max_samples:
67+
audio_array = audio_array[:, :max_samples]
68+
69+
return audio_array
70+
71+
@staticmethod
72+
def _load_mask_latents(rs2v, mask_files):
73+
if mask_files is None:
74+
return None
75+
latents = [rs2v.process_single_mask(f) for f in mask_files]
76+
return torch.cat(latents, dim=0)
77+
78+
@staticmethod
79+
def _calc_total_clips(total_samples, audio_per_frame, target_video_length):
80+
total_frames = int(np.ceil(total_samples / audio_per_frame))
81+
if total_frames <= target_video_length:
82+
return 1
83+
remaining = total_frames - target_video_length
84+
return 1 + int(np.ceil(remaining / (target_video_length + 3)))
85+
86+
@staticmethod
87+
def _update_latent_shape(clip_input_info, target_len, vae_stride):
88+
if hasattr(clip_input_info, "latent_shape") and clip_input_info.latent_shape is not None:
89+
s = clip_input_info.latent_shape
90+
new_t = (target_len - 1) // vae_stride + 1
91+
clip_input_info.latent_shape = [s[0], new_t, s[2], s[3]]
92+
93+
def _compute_segment_params(self, idx, audio_clip, pad_len, target_video_length, target_fps, audio_per_frame, vae_stride, clip_input_info):
94+
"""Compute per-segment parameters (target_video_length, latent_shape, trimmed audio_clip).
95+
96+
Returns:
97+
(is_first, is_last, segment_actual_video_frames, audio_clip)
98+
"""
99+
is_first = idx == 0
100+
is_last = pad_len > 0
101+
segment_actual_video_frames = None
102+
103+
if is_last:
104+
actual_audio_samples = audio_clip.shape[1] - pad_len
105+
actual_video_frames = int(np.ceil(actual_audio_samples / audio_per_frame))
106+
segment_actual_video_frames = actual_video_frames
107+
108+
seg_target_len = calculate_target_video_length_from_duration(actual_video_frames / target_fps, target_fps)
109+
clip_input_info.target_video_length = seg_target_len
110+
self._update_latent_shape(clip_input_info, seg_target_len, vae_stride)
111+
112+
logger.info(
113+
f"Segment {idx}: Last segment with pad_len={pad_len}, "
114+
f"actual_video_frames={actual_video_frames}, "
115+
f"calculated target_video_length={seg_target_len}, "
116+
f"latent_shape={clip_input_info.latent_shape}"
117+
)
118+
audio_clip = audio_clip[:, : clip_input_info.target_video_length * audio_per_frame]
119+
else:
120+
cur_clip_len = target_video_length if is_first else (target_video_length + 3)
121+
clip_input_info.target_video_length = cur_clip_len
122+
if not is_first:
123+
self._update_latent_shape(clip_input_info, cur_clip_len, vae_stride)
124+
125+
return is_first, is_last, segment_actual_video_frames, audio_clip
126+
127+
@staticmethod
128+
def _trim_segment(gen_clip_video, audio_clip, segment_actual_video_frames, audio_per_frame):
129+
if segment_actual_video_frames is not None:
130+
video_seg = gen_clip_video[:, :, :segment_actual_video_frames]
131+
audio_seg = audio_clip[:, : segment_actual_video_frames * audio_per_frame].sum(dim=0)
132+
else:
133+
video_seg = gen_clip_video
134+
audio_seg = audio_clip.sum(dim=0)
135+
return video_seg, audio_seg
136+
137+
@staticmethod
138+
def _merge_and_save(gen_video_list, cut_audio_list, target_fps, save_result_path):
139+
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
140+
gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
141+
merge_audio = torch.cat(cut_audio_list, dim=0).numpy().astype(np.float32)
142+
143+
if is_main_process() and save_result_path:
144+
out_path = os.path.join("./", "video_merge.mp4")
145+
audio_file = os.path.join("./", "audio_merge.wav")
146+
save_to_video(gen_lvideo, out_path, target_fps)
147+
save_audio(merge_audio, audio_file, out_path, output_path=save_result_path)
148+
os.remove(out_path)
149+
os.remove(audio_file)
150+
151+
return gen_lvideo, merge_audio
152+
31153
@torch.no_grad()
32154
def generate(self, args):
33155
rs2v = self.clip_generators["rs2v_clip"]
34-
# 获取此clip模型的配置信息
156+
35157
target_fps = rs2v.config.get("target_fps", 16)
36158
audio_sr = rs2v.config.get("audio_sr", 16000)
37159
audio_per_frame = audio_sr // target_fps
160+
vae_stride = rs2v.config["vae_stride"][0]
38161

39-
# 获取用户输入信息
40162
clip_input_info = init_input_info_from_args(rs2v.config["task"], args)
41-
# 从默认配置中补全输入信息
42163
clip_input_info = self.check_input_info(clip_input_info, rs2v.config)
43-
target_video_length = clip_input_info.target_video_length
44164

45-
# Auto-calculate target_video_length from video_duration if not explicitly provided
46165
if clip_input_info.target_video_length is None or clip_input_info.target_video_length == UNSET:
47166
if clip_input_info.video_duration is not None and clip_input_info.video_duration != UNSET:
48-
# Calculate for the first segment (max 5s)
49167
segment_duration = min(clip_input_info.video_duration, 5.0)
50168
clip_input_info.target_video_length = calculate_target_video_length_from_duration(segment_duration, target_fps)
51169
logger.info(f"Auto-calculated target_video_length={clip_input_info.target_video_length} from video_duration={clip_input_info.video_duration}s (segment={segment_duration}s)")
52170
else:
53-
# Fallback to config default
54171
clip_input_info.target_video_length = rs2v.config.get("target_video_length", 81)
55172

56173
target_video_length = clip_input_info.target_video_length
174+
base_seed = clip_input_info.seed
57175

58-
gen_video_list = []
59-
cut_audio_list = []
60-
video_duration = clip_input_info.video_duration
61-
62-
def get_audio_files_from_audio_path(audio_path):
63-
if os.path.isdir(audio_path):
64-
audio_files = []
65-
mask_files = []
66-
audio_config_path = os.path.join(audio_path, "config.json")
67-
assert os.path.exists(audio_config_path), "config.json not found in audio_path"
68-
with open(audio_config_path, "r") as f:
69-
audio_config = json.load(f)
70-
for talk_object in audio_config["talk_objects"]:
71-
audio_files.append(os.path.join(audio_path, talk_object["audio"]))
72-
mask_files.append(os.path.join(audio_path, talk_object["mask"]))
73-
else:
74-
audio_files = [audio_path]
75-
mask_files = None
76-
return audio_files, mask_files
77-
78-
def load_audio(audio_path, target_sr):
79-
arr, ori_sr = ta.load(audio_path)
80-
arr = arr.mean(0)
81-
if ori_sr != target_sr:
82-
arr = ta.functional.resample(arr, ori_sr, target_sr)
83-
return arr
84-
85-
audio_files, mask_files = get_audio_files_from_audio_path(clip_input_info.audio_path)
176+
audio_files, mask_files = self._parse_audio_path(clip_input_info.audio_path)
86177
clip_input_info.audio_num = len(audio_files)
87178

88-
if len(audio_files) == 1:
89-
audio_array = load_audio(audio_files[0], audio_sr)
90-
audio_array = audio_array.unsqueeze(0)
91-
else:
92-
audio_arrays = []
93-
max_len = 0
94-
for a_file in audio_files:
95-
arr = load_audio(a_file, audio_sr)
96-
audio_arrays.append(arr)
97-
max_len = max(max_len, arr.numel())
98-
num_files = len(audio_arrays)
99-
audio_array = torch.zeros(num_files, max_len, dtype=torch.float32)
100-
for i, arr in enumerate(audio_arrays):
101-
length = arr.numel()
102-
audio_array[i, :length] = arr
103-
104-
if video_duration is not None and video_duration > 0:
105-
max_samples = int(video_duration * audio_sr)
106-
if audio_array.shape[1] > max_samples:
107-
audio_array = audio_array[:, :max_samples]
108-
109-
if mask_files is not None:
110-
mask_latents = [rs2v.process_single_mask(mask_file) for mask_file in mask_files]
111-
person_mask_latens = torch.cat(mask_latents, dim=0)
112-
else:
113-
person_mask_latens = None
179+
audio_array = self._load_audio_array(audio_files, audio_sr, clip_input_info.video_duration)
180+
person_mask_latens = self._load_mask_latents(rs2v, mask_files)
114181

115182
audio_reader = RS2V_SlidingWindowReader(audio_array, first_clip_len=target_video_length, clip_len=target_video_length + 3, sr=audio_sr, fps=target_fps)
183+
total_clips = self._calc_total_clips(audio_array.shape[1], audio_per_frame, target_video_length)
184+
ref_state_seq = get_reference_state_sequence(target_video_length - 3, target_fps)
116185

117-
total_frames = int(np.ceil(audio_array.shape[1] / audio_per_frame))
118-
if total_frames <= target_video_length:
119-
total_clips = 1
120-
else:
121-
remaining = total_frames - target_video_length
122-
total_clips = 1 + int(np.ceil(remaining / (target_video_length + 3)))
123-
124-
ref_state_sq = get_reference_state_sequence(target_video_length - 3, target_fps)
125-
126-
# 预先运行输入编码的静态部分 (处理ref image的vae编码和文本编码)
127186
rs2v.input_info = clip_input_info
128187
rs2v.inputs_static = rs2v._run_input_encoder_local_rs2v_static()
129188

@@ -132,90 +191,33 @@ def load_audio(audio_path, target_sr):
132191
self.va_controller = VAController(rs2v)
133192
logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
134193

135-
idx = 0
136-
while True:
137-
audio_clip, pad_len = audio_reader.next_frame()
194+
gen_video_list = []
195+
cut_audio_list = []
196+
197+
for idx, (audio_clip, pad_len) in enumerate(iter(audio_reader.next_frame, (None, 0))):
138198
if audio_clip is None:
139199
break
200+
rs2v.check_stop()
140201

141-
is_first = True if idx == 0 else False
142-
is_last = True if pad_len > 0 else False
143-
144-
pipe = rs2v
145-
pipe.check_stop()
146-
147-
# Calculate actual target_video_length for this segment based on audio length
148-
segment_actual_video_frames = None # Track for last-segment trimming
149-
if is_last and pad_len > 0:
150-
# For the last segment with padding, calculate actual video frames needed
151-
actual_audio_samples = audio_clip.shape[1] - pad_len
152-
actual_video_frames = int(np.ceil(actual_audio_samples / audio_per_frame))
153-
segment_actual_video_frames = actual_video_frames
154-
# Apply the formula to ensure VAE stride constraint
155-
segment_target_video_length = calculate_target_video_length_from_duration(actual_video_frames / target_fps, target_fps)
156-
clip_input_info.target_video_length = segment_target_video_length
157-
158-
# Recalculate latent_shape for this segment
159-
# latent_shape = [C, T, H, W] where T = (target_video_length - 1) // vae_stride[0] + 1
160-
if hasattr(clip_input_info, "latent_shape") and clip_input_info.latent_shape is not None:
161-
original_latent_shape = clip_input_info.latent_shape
162-
# Update only the temporal dimension (index 1)
163-
new_latent_t = (segment_target_video_length - 1) // rs2v.config["vae_stride"][0] + 1
164-
clip_input_info.latent_shape = [
165-
original_latent_shape[0], # C
166-
new_latent_t, # T
167-
original_latent_shape[2], # H
168-
original_latent_shape[3], # W
169-
]
170-
171-
logger.info(
172-
f"Segment {idx}: Last segment with pad_len={pad_len}, "
173-
f"actual_video_frames={actual_video_frames}, "
174-
f"calculated target_video_length={segment_target_video_length}, "
175-
f"latent_shape={clip_input_info.latent_shape}"
176-
)
177-
else:
178-
# For first/middle segments, use the original target_video_length
179-
cur_clip_len = target_video_length if is_first else (target_video_length + 3)
180-
clip_input_info.target_video_length = cur_clip_len
181-
182-
# Recalculate latent_shape for non-first segments
183-
if not is_first and hasattr(clip_input_info, "latent_shape") and clip_input_info.latent_shape is not None:
184-
original_latent_shape = clip_input_info.latent_shape
185-
new_latent_t = (cur_clip_len - 1) // rs2v.config["vae_stride"][0] + 1
186-
clip_input_info.latent_shape = [
187-
original_latent_shape[0],
188-
new_latent_t,
189-
original_latent_shape[2],
190-
original_latent_shape[3],
191-
]
202+
is_first, is_last, segment_actual_frames, audio_clip = self._compute_segment_params(idx, audio_clip, pad_len, target_video_length, target_fps, audio_per_frame, vae_stride, clip_input_info)
192203

193204
clip_input_info.is_first = is_first
194205
clip_input_info.is_last = is_last
195-
clip_input_info.ref_state = ref_state_sq[idx % len(ref_state_sq)]
196-
clip_input_info.seed = clip_input_info.seed + idx
206+
clip_input_info.ref_state = ref_state_seq[idx % len(ref_state_seq)]
207+
clip_input_info.seed = base_seed + idx
197208
clip_input_info.audio_clip = audio_clip
198-
idx = idx + 1
209+
clip_input_info.person_mask_latens = person_mask_latens
210+
199211
if self.progress_callback:
200-
self.progress_callback(idx, total_clips)
212+
self.progress_callback(idx + 1, total_clips)
201213

202214
rs2v.input_info = clip_input_info
203-
clip_input_info.person_mask_latens = person_mask_latens
204-
205-
# 使用动态输入获取当前 clip 控制参数
206215
rs2v.inputs = rs2v._run_input_encoder_local_rs2v_dynamic()
207-
208216
gen_clip_video, audio_clip, gen_latents = rs2v.run_clip_main()
209-
logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}")
210217

211-
if segment_actual_video_frames is not None:
212-
# Last segment: trim to exact actual frames needed
213-
video_seg = gen_clip_video[:, :, :segment_actual_video_frames]
214-
audio_seg = audio_clip[:, : segment_actual_video_frames * audio_per_frame].sum(dim=0)
215-
else:
216-
video_seg = gen_clip_video
217-
audio_seg = audio_clip.sum(dim=0)
218+
logger.info(f"Generated rs2v clip {idx + 1}, pad_len={pad_len}, gen_clip_video={gen_clip_video.shape}, audio_clip={audio_clip.shape}, gen_latents={gen_latents.shape}")
218219

220+
video_seg, audio_seg = self._trim_segment(gen_clip_video, audio_clip, segment_actual_frames, audio_per_frame)
219221
clip_input_info.overlap_latent = gen_latents[:, -1:]
220222

221223
if clip_input_info.return_result_tensor or not clip_input_info.stream_save_video:
@@ -229,23 +231,10 @@ def load_audio(audio_path, target_sr):
229231
if not clip_input_info.return_result_tensor and clip_input_info.stream_save_video:
230232
return None, None, None
231233

232-
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
233-
gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
234-
merge_audio = torch.cat(cut_audio_list, dim=0).numpy().astype(np.float32)
235-
236-
if is_main_process() and clip_input_info.save_result_path:
237-
out_path = os.path.join("./", "video_merge.mp4")
238-
audio_file = os.path.join("./", "audio_merge.wav")
239-
240-
save_to_video(gen_lvideo, out_path, 16)
241-
save_audio(merge_audio, audio_file, out_path, output_path=clip_input_info.save_result_path)
242-
os.remove(out_path)
243-
os.remove(audio_file)
244-
234+
gen_lvideo, merge_audio = self._merge_and_save(gen_video_list, cut_audio_list, target_fps, clip_input_info.save_result_path)
245235
return gen_lvideo, merge_audio, audio_sr
246236

247237
def run_pipeline(self, input_info):
248-
# input_info = self.update_input_info(input_info)
249238
try:
250239
gen_lvideo, merge_audio, audio_sr = self.generate(input_info)
251240
finally:

0 commit comments

Comments (0)