31 changes: 6 additions & 25 deletions lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
@@ -16,7 +16,6 @@

 from lightllm.server.multimodal_params import AudioItem
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
-from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor
@@ -207,9 +206,6 @@ def __init__(
         self.proj2 = nn.Linear(d_model, output_dim)
         self.n_window_infer = n_window_infer
         self.conv_chunksize = conv_chunksize
-
-        self.cache_port = kvargs["cache_port"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self._init_datatype()

     def _init_datatype(self):
@@ -337,7 +333,7 @@ def forward(
         hidden_states = self.proj2(hidden_states)
         return hidden_states

-    def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+    def encode(self, audio_items: List[AudioItem]):
         uuids = []
         items: List[AudioItem] = []
         per_audio_features: List[torch.Tensor] = []
@@ -368,24 +364,9 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
             )
             per_audio_features.append(audio_features)

-        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
-        ids_to_set = []
-        for i, ready in enumerate(ready_audio):
-            if ready:
-                continue
-
-            uid = uuids[i]
-            item = items[i]
-
+        all_embeds = []
+        for i in range(len(audio_items)):
             cur_embed = per_audio_features[i]
-            cpu_embed_cache_client.copy_to_cache(
-                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
-            )
-            assert (
-                item.token_num == cur_embed.shape[0]
-            ), f"audio token num not match {item.token_num} vs {cur_embed.shape[0]} "
-            ids_to_set.append(uid)
-
-        if ids_to_set:
-            self.cache_client.root.set_items_embed(ids=ids_to_set)
-        torch.cuda.current_stream().synchronize()
+            all_embeds.append(cur_embed)
+
+        return all_embeds, audio_items
Comment on lines +367 to +372
Contributor (severity: medium):
The loop to build all_embeds is redundant because per_audio_features is already a list containing the embeddings for each item in audio_items in the same order.

Suggested change:
-        all_embeds = []
-        for i in range(len(audio_items)):
-            cur_embed = per_audio_features[i]
-            all_embeds.append(cur_embed)
-
-        return all_embeds, audio_items
+        return per_audio_features, audio_items
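With the suggestion applied, `encode` would simply return `per_audio_features`. A minimal sketch of why the two forms are equivalent (hypothetical tensors, not code from this PR):

```python
import torch

# per_audio_features[i] was appended for audio_items[i], so the copy loop
# preserves both contents and order; the list it builds is element-wise
# identical to per_audio_features.
per_audio_features = [torch.randn(3, 8), torch.randn(5, 8)]  # hypothetical embeds
all_embeds = []
for i in range(len(per_audio_features)):
    all_embeds.append(per_audio_features[i])
assert all(a is b for a, b in zip(all_embeds, per_audio_features))
```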

24 changes: 6 additions & 18 deletions lightllm/models/whisper/whisper_audio.py
@@ -12,7 +12,6 @@
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from lightllm.server.multimodal_params import AudioItem
 from rpyc.utils.classic import obtain
-from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient

 # tokenizer_class removed
 class WhisperProcessor(ProcessorMixin):
@@ -89,8 +88,6 @@ def __init__(self, kvargs):
         self.max_seconds = 30
         self.sampling_rate = 16000
         self.max_length = self.max_seconds * self.sampling_rate
-        self.cache_port = kvargs["cache_port"]
-        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         data_type = kvargs["data_type"]
         if data_type in ["bf16", "bfloat16"]:
             self.data_type = torch.bfloat16
@@ -162,7 +159,7 @@ def forward(self, audio_values, audio_lens_after_cnn):
         x = F.linear(x, weight=self.projector_weights["mlp2.3.weight"], bias=self.projector_weights["mlp2.3.bias"])
         return x

-    def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
+    def encode(self, audio_items: List[AudioItem]):
         # each element is one chunk
         batch_audios = []
         batch_audio_lens = []
@@ -222,22 +219,13 @@ def encode(self, audio_items: List[AudioItem], cpu_embed_cache_client: CpuEmbedCacheClient):
                 continue
             per_audio_embeds[owner].append(audios[chunk_idx][:token_len])

-        ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
-        ids_to_set = []
-        for i, ready in enumerate(ready_audio):
-            if ready:
-                continue
+        ans_embeds = []
+        for i in range(len(uuids)):

-            uid = uuids[i]
-            item = items[i]

             # concatenate all chunk embeddings of this audio
             cur_embed = torch.cat(per_audio_embeds[i], dim=0)
-            cpu_embed_cache_client.copy_to_cache(
-                embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
-            )
-            ids_to_set.append(uid)

-        if ids_to_set:
-            self.cache_client.root.set_items_embed(ids=ids_to_set)
-        torch.cuda.current_stream().synchronize()
+            ans_embeds.append(cur_embed)
+
+        return ans_embeds, audio_items
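Both audio encoders now hand their embeddings straight back to the caller rather than publishing them through the rpyc embed cache. A hedged sketch of the caller-side contract implied by the new return value (names here are illustrative, not from this PR):

```python
from typing import List

import torch

def consume_audio_embeds(audio_model, audio_items: List) -> List[torch.Tensor]:
    # encode() now returns the embeddings directly; embeds[i] corresponds
    # to audio_items[i] in order, instead of being written to the cache.
    embeds, items = audio_model.encode(audio_items)
    for emb, item in zip(embeds, items):
        assert emb.shape[0] == item.token_num  # one embedding row per audio token
    return embeds
```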
26 changes: 26 additions & 0 deletions lightllm/server/api_cli.py
@@ -484,6 +484,32 @@ def make_argument_parser() -> argparse.ArgumentParser:
         transfer image to embed.
         """,
     )
+    parser.add_argument(
+        "--audio_gpu_ids", nargs="+", type=int, default=None, help="GPU IDs for audio encoder, e.g., 0 1 2"
+    )
+    parser.add_argument(
+        "--audio_tp",
+        type=int,
+        default=1,
+        help="Tensor parallel size for audio encoder (only 1 is supported; use audio_dp to scale)",
+    )
+    parser.add_argument("--audio_dp", type=int, default=1, help="Data parallel replicas for audio encoder")
+    parser.add_argument(
+        "--audio_nccl_ports",
+        nargs="+",
+        type=int,
+        default=None,
+        help="NCCL ports, one per audio DP group; if omitted, auto-allocated in api_start (reserved until audio_tp > 1 is supported)",
+    )
+    parser.add_argument(
+        "--audio_infer_batch_size",
+        type=int,
+        default=None,
+        help="""
+        Max audio items per GPU infer batch in the audio worker (default: 4 * audio_dp;
+        must be a multiple of audio_dp)
+        """,
+    )
     parser.add_argument(
         "--visual_use_proxy_mode",
         action="store_true",
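For reference, a quick sketch of how the new flags compose (hypothetical values; assumes the parser's remaining arguments all have defaults):

```python
from lightllm.server.api_cli import make_argument_parser

# Two audio DP replicas on GPUs 0 and 1, batching up to 8 audio items per
# infer step (a multiple of audio_dp, as api_start enforces).
parser = make_argument_parser()
args, _ = parser.parse_known_args(
    ["--audio_gpu_ids", "0", "1", "--audio_dp", "2", "--audio_infer_batch_size", "8"]
)
assert args.audio_tp == 1  # only TP=1 is supported; scale with --audio_dp
```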
48 changes: 39 additions & 9 deletions lightllm/server/api_start.py
@@ -204,6 +204,32 @@ def normal_or_p_d_start(args):
                 f"a positive integer multiple of visual_dp ({args.visual_dp})"
             )

+    if not args.disable_audio:
+        if args.audio_tp != 1:
+            raise ValueError(
+                "audio_tp > 1 is not supported for the audio encoder yet; use --audio_dp for multi-GPU data parallel."
+            )
+        if args.audio_gpu_ids is None:
+            args.audio_gpu_ids = list(range(args.audio_dp * args.audio_tp))
+        total_audio_gpus = args.audio_dp * args.audio_tp
+        if len(args.audio_gpu_ids) < total_audio_gpus:
+            raise ValueError(
+                f"Not enough audio GPUs specified. Need at least {total_audio_gpus}, "
+                f"but got {len(args.audio_gpu_ids)}."
+            )
+        args.audio_gpu_ids = args.audio_gpu_ids[:total_audio_gpus]
+        if args.audio_dp <= 0:
+            raise ValueError("audio_dp must be a positive integer.")
+        if args.audio_infer_batch_size is None:
+            args.audio_infer_batch_size = args.audio_dp * 4
+        if args.audio_infer_batch_size < 1:
+            raise ValueError("audio_infer_batch_size must be >= 1.")
+        if args.audio_infer_batch_size // args.audio_dp < 1 or args.audio_infer_batch_size % args.audio_dp != 0:
+            raise ValueError(
+                f"audio_infer_batch_size ({args.audio_infer_batch_size}) must be "
+                f"a positive integer multiple of audio_dp ({args.audio_dp})."
+            )
+
     if args.disable_chunked_prefill:
         args.chunked_prefill_size = args.max_req_total_len
     # in normal mode
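To make the derived defaults concrete, a small hedged sketch of the rules above (hypothetical values):

```python
from argparse import Namespace

# Hypothetical args: two audio DP replicas, everything else left unset.
args = Namespace(disable_audio=False, audio_tp=1, audio_dp=2,
                 audio_gpu_ids=None, audio_infer_batch_size=None)

args.audio_gpu_ids = list(range(args.audio_dp * args.audio_tp))  # -> [0, 1]
args.audio_infer_batch_size = args.audio_dp * 4                  # -> 8
assert args.audio_infer_batch_size % args.audio_dp == 0          # multiple of audio_dp
```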
@@ -248,6 +274,8 @@ def normal_or_p_d_start(args):
     already_uesd_ports.append(args.pd_decode_rpyc_port)
     if args.visual_nccl_ports is not None:
         already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])
+    if not args.disable_audio and args.audio_nccl_ports is not None:
+        already_uesd_ports.extend(args.audio_nccl_ports[: args.audio_dp])

     # Lock the ports ahead of time so that, when several instances start on one
     # machine, port conflicts are caught immediately rather than at model startup
@@ -256,7 +284,7 @@ def normal_or_p_d_start(args):

     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp,
+        num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp,
         used_ports=already_uesd_ports,
     )
     logger.info(f"alloced ports: {can_use_ports}")
@@ -274,14 +302,17 @@ def normal_or_p_d_start(args):
     ) = can_use_ports[0:10]
     can_use_ports = can_use_ports[10:]

-    visual_nccl_ports = []
-    for _ in range(args.visual_dp):
-        if args.visual_nccl_ports is None:
-            visual_nccl_ports.append(can_use_ports[0])
-            can_use_ports = can_use_ports[1:]
+    if args.visual_nccl_ports is None:
+        args.visual_nccl_ports = can_use_ports[: args.visual_dp]
+        can_use_ports = can_use_ports[args.visual_dp :]
+    else:
+        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

-    if args.visual_nccl_ports is not None:
-        visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+    if args.audio_nccl_ports is None:
+        args.audio_nccl_ports = can_use_ports[: args.audio_dp]
+        can_use_ports = can_use_ports[args.audio_dp :]
+    else:
+        args.audio_nccl_ports = args.audio_nccl_ports[: args.audio_dp]

     # put the allocated ports into the args namespace
     if args.nccl_port is None:
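For orientation, a hedged sketch of how the flat port list is consumed after this change (hypothetical sizes and port numbers; the visual worker ports are presumably consumed later by the visual processes):

```python
node_world_size, visual_dp, visual_tp, audio_dp = 4, 2, 1, 2
num = 10 + node_world_size + visual_dp * visual_tp + visual_dp + audio_dp
ports = list(range(20010, 20010 + num))  # stand-in for alloc_can_use_network_port

named, rest = ports[:10], ports[10:]                     # http/router/cache/metric/... ports
visual_nccl, rest = rest[:visual_dp], rest[visual_dp:]   # one NCCL port per visual DP group
audio_nccl, rest = rest[:audio_dp], rest[audio_dp:]      # one NCCL port per audio DP group (new)
pd_infer_rpyc = rest[:node_world_size]                   # p/d disaggregation rpyc ports
```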
@@ -296,7 +327,6 @@ def normal_or_p_d_start(args):
     args.cache_port = cache_port
     args.metric_port = metric_port
     args.multi_level_kv_cache_port = multi_level_kv_cache_port
-    args.visual_nccl_ports = visual_nccl_ports
     # allocate the ports used in p/d disaggregation mode
     args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
     # the id used to identify this node in p/d disaggregation mode