Commit 93dd2a6

audio server improve. (#1254)

1 parent 0fce8c9 · commit 93dd2a6

14 files changed
Lines changed: 1012 additions & 219 deletions

lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py

Lines changed: 6 additions & 25 deletions
@@ -16,7 +16,6 @@
 
 from lightllm.server.multimodal_params import AudioItem
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
-from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor
@@ -207,9 +206,6 @@ def __init__(
         self.proj2 = nn.Linear(d_model, output_dim)
         self.n_window_infer = n_window_infer
         self.conv_chunksize = conv_chunksize
-
-        self.cache_port = kvargs["cache_port"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self._init_datatype()
 
     def _init_datatype(self):
@@ -337,7 +333,7 @@ def forward(
         hidden_states = self.proj2(hidden_states)
         return hidden_states
 
-    def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+    def encode(self, audio_items: List[AudioItem]):
         uuids = []
         items: List[AudioItem] = []
         per_audio_features: List[torch.Tensor] = []
@@ -368,24 +364,9 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
             )
             per_audio_features.append(audio_features)
 
-        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
-        ids_to_set = []
-        for i, ready in enumerate(ready_audio):
-            if ready:
-                continue
-
-            uid = uuids[i]
-            item = items[i]
-
+        all_embeds = []
+        for i in range(len(audio_items)):
             cur_embed = per_audio_features[i]
-            cpu_embed_cache_client.copy_to_cache(
-                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
-            )
-            assert (
-                item.token_num == cur_embed.shape[0]
-            ), f"audio token num not match {item.token_num} vs {cur_embed.shape[0]} "
-            ids_to_set.append(uid)
-
-        if ids_to_set:
-            self.cache_client.root.set_items_embed(ids=ids_to_set)
-        torch.cuda.current_stream().synchronize()
+            all_embeds.append(cur_embed)
+
+        return all_embeds, audio_items
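
With this change, encode no longer performs the embed-cache round trip itself (the rpyc cache client and the CpuEmbedCacheClient copy are gone); it simply returns the per-audio embedding tensors together with the items, and caching becomes the caller's job. A minimal caller-side sketch of the new contract, where audio_model and audio_items are hypothetical stand-ins for an initialized encoder and its inputs:

    # Hypothetical caller-side sketch; audio_model / audio_items are placeholders.
    embeds, items = audio_model.encode(audio_items)
    for embed, item in zip(embeds, items):
        # One embedding row per audio token, as the removed assert used to verify.
        assert embed.shape[0] == item.token_num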

lightllm/models/whisper/whisper_audio.py

Lines changed: 6 additions & 18 deletions
@@ -12,7 +12,6 @@
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from lightllm.server.multimodal_params import AudioItem
 from rpyc.utils.classic import obtain
-from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 
 # tokenizer_class removed
 class WhisperProcessor(ProcessorMixin):
@@ -89,8 +88,6 @@ def __init__(self, kvargs):
         self.max_seconds = 30
         self.sampling_rate = 16000
         self.max_length = self.max_seconds * self.sampling_rate
-        self.cache_port = kvargs["cache_port"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         data_type = kvargs["data_type"]
         if data_type in ["bf16", "bfloat16"]:
             self.data_type = torch.bfloat16
@@ -162,7 +159,7 @@ def forward(self, audio_values, audio_lens_after_cnn):
         x = F.linear(x, weight=self.projector_weights["mlp2.3.weight"], bias=self.projector_weights["mlp2.3.bias"])
         return x
 
-    def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+    def encode(self, audio_items: List[AudioItem]):
         # each element is one chunk
         batch_audios = []
         batch_audio_lens = []
@@ -222,22 +219,13 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
                 continue
             per_audio_embeds[owner].append(audios[chunk_idx][:token_len])
 
-        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
-        ids_to_set = []
-        for i, ready in enumerate(ready_audio):
-            if ready:
-                continue
+        ans_embeds = []
+        for i in range(len(uuids)):
 
-            uid = uuids[i]
             item = items[i]
 
             # concatenate all chunk embeddings of this audio
             cur_embed = torch.cat(per_audio_embeds[i], dim=0)
-            cpu_embed_cache_client.copy_to_cache(
-                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
-            )
-            ids_to_set.append(uid)
-
-        if ids_to_set:
-            self.cache_client.root.set_items_embed(ids=ids_to_set)
-        torch.cuda.current_stream().synchronize()
+            ans_embeds.append(cur_embed)
+
+        return ans_embeds, audio_items
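
Audios longer than max_seconds are encoded chunk by chunk, and the hunk above keeps the per-audio torch.cat that stitches the chunk embeddings back together before they are returned. A self-contained illustration of that concatenation (the shapes are made up):

    import torch

    # Two hypothetical chunk embeddings of one audio, shaped (tokens, hidden).
    chunks = [torch.randn(120, 1280), torch.randn(80, 1280)]
    cur_embed = torch.cat(chunks, dim=0)  # (200, 1280): all tokens of the audio, in order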

lightllm/server/api_cli.py

Lines changed: 26 additions & 0 deletions
@@ -484,6 +484,32 @@ def make_argument_parser() -> argparse.ArgumentParser:
         transfer image to embed.
         """,
     )
+    parser.add_argument(
+        "--audio_gpu_ids", nargs="+", type=int, default=None, help="GPU IDs for audio encoder, e.g., 0 1 2"
+    )
+    parser.add_argument(
+        "--audio_tp",
+        type=int,
+        default=1,
+        help="Tensor parallel size for audio encoder (only 1 is supported; use audio_dp to scale)",
+    )
+    parser.add_argument("--audio_dp", type=int, default=1, help="Data parallel replicas for audio encoder")
+    parser.add_argument(
+        "--audio_nccl_ports",
+        nargs="+",
+        type=int,
+        default=None,
+        help="NCCL ports per audio DP group; if omitted, auto-allocated in api_start (reserved for future audio_tp>1 support)",
+    )
+    parser.add_argument(
+        "--audio_infer_batch_size",
+        type=int,
+        default=None,
+        help="""
+        Max audio items per GPU infer batch in audio worker (default: 4 * audio_dp,
+        must be a multiple of audio_dp)
+        """,
+    )
     parser.add_argument(
         "--visual_use_proxy_mode",
         action="store_true",
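
A quick way to see how the new flags parse, assuming no other argument of this parser is required at parse time (the values are illustrative only):

    from lightllm.server.api_cli import make_argument_parser

    # Illustrative: two data-parallel audio replicas on GPUs 0 and 1.
    args = make_argument_parser().parse_args(["--audio_dp", "2", "--audio_gpu_ids", "0", "1"])
    assert args.audio_tp == 1                   # the only supported value for now
    assert args.audio_nccl_ports is None        # auto-allocated later in api_start
    assert args.audio_infer_batch_size is None  # defaulted to 4 * audio_dp in api_start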

lightllm/server/api_start.py

Lines changed: 39 additions & 9 deletions
@@ -204,6 +204,32 @@ def normal_or_p_d_start(args):
             f"a positive integer multiple of visual_dp ({args.visual_dp})"
         )
 
+    if not args.disable_audio:
+        if args.audio_tp != 1:
+            raise ValueError(
+                "audio_tp > 1 is not supported for the audio encoder yet; use --audio_dp for multi-GPU data parallel."
+            )
+        if args.audio_gpu_ids is None:
+            args.audio_gpu_ids = list(range(args.audio_dp * args.audio_tp))
+        total_audio_gpus = args.audio_dp * args.audio_tp
+        if len(args.audio_gpu_ids) < total_audio_gpus:
+            raise ValueError(
+                f"Not enough audio GPUs specified. Need at least {total_audio_gpus}, "
+                f"but got {len(args.audio_gpu_ids)}."
+            )
+        args.audio_gpu_ids = args.audio_gpu_ids[:total_audio_gpus]
+        if args.audio_dp <= 0:
+            raise ValueError("audio_dp must be a positive integer.")
+        if args.audio_infer_batch_size is None:
+            args.audio_infer_batch_size = args.audio_dp * 4
+        if args.audio_infer_batch_size < 1:
+            raise ValueError("audio_infer_batch_size must be >= 1.")
+        if args.audio_infer_batch_size // args.audio_dp < 1 or args.audio_infer_batch_size % args.audio_dp != 0:
+            raise ValueError(
+                f"audio_infer_batch_size ({args.audio_infer_batch_size}) must be "
+                f"a positive integer multiple of audio_dp ({args.audio_dp})."
+            )
+
     if args.disable_chunked_prefill:
         args.chunked_prefill_size = args.max_req_total_len
     # in normal mode
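
The audio_infer_batch_size checks above only guarantee that a batch divides evenly across the DP replicas; a worked example of the arithmetic with made-up values:

    audio_dp = 2
    audio_infer_batch_size = audio_dp * 4             # the default filled in above: 8
    per_replica = audio_infer_batch_size // audio_dp  # 4 items per DP replica
    assert audio_infer_batch_size % audio_dp == 0     # the multiple-of-audio_dp rule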
@@ -248,6 +274,8 @@ def normal_or_p_d_start(args):
         already_uesd_ports.append(args.pd_decode_rpyc_port)
     if args.visual_nccl_ports is not None:
         already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])
+    if not args.disable_audio and args.audio_nccl_ports is not None:
+        already_uesd_ports.extend(args.audio_nccl_ports[: args.audio_dp])
 
     # Lock the ports in advance, so that when multiple instances are started on the same machine,
     # port conflicts are caught here rather than only when the model starts.
@@ -256,7 +284,7 @@
 
     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp,
+        num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp,
         used_ports=already_uesd_ports,
     )
     logger.info(f"alloced ports: {can_use_ports}")
@@ -274,14 +302,17 @@
     ) = can_use_ports[0:10]
     can_use_ports = can_use_ports[10:]
 
-    visual_nccl_ports = []
-    for _ in range(args.visual_dp):
-        if args.visual_nccl_ports is None:
-            visual_nccl_ports.append(can_use_ports[0])
-            can_use_ports = can_use_ports[1:]
+    if args.visual_nccl_ports is None:
+        args.visual_nccl_ports = can_use_ports[: args.visual_dp]
+        can_use_ports = can_use_ports[args.visual_dp :]
+    else:
+        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
 
-    if args.visual_nccl_ports is not None:
-        visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+    if args.audio_nccl_ports is None:
+        args.audio_nccl_ports = can_use_ports[: args.audio_dp]
+        can_use_ports = can_use_ports[args.audio_dp :]
+    else:
+        args.audio_nccl_ports = args.audio_nccl_ports[: args.audio_dp]
 
     # put the allocated ports into args
     if args.nccl_port is None:
@@ -296,7 +327,6 @@
     args.cache_port = cache_port
    args.metric_port = metric_port
     args.multi_level_kv_cache_port = multi_level_kv_cache_port
-    args.visual_nccl_ports = visual_nccl_ports
     # allocate the ports used in p/d disaggregation mode
     args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
     # id used to identify the node in p/d disaggregation mode
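
Putting the port changes together: the allocator now reserves audio_dp extra ports, and the slicing above hands one NCCL port to each audio DP group. A worked count under made-up settings (the label for the visual worker ports is an assumption, since those ports are consumed outside these hunks):

    node_world_size = 4          # tp // nnodes
    visual_dp, visual_tp = 2, 1
    audio_dp = 2

    num = 10 + node_world_size + visual_dp * visual_tp + visual_dp + audio_dp
    # = 10 fixed server ports (can_use_ports[0:10])
    #   + 4 pd_node_infer_rpyc_ports
    #   + 2 visual worker ports (visual_dp * visual_tp, assumed)
    #   + 2 visual NCCL ports + 2 audio NCCL ports
    assert num == 20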
