@@ -204,6 +204,32 @@ def normal_or_p_d_start(args):
204204 f"a positive integer multiple of visual_dp ({ args .visual_dp } )"
205205 )
206206
207+ if not args .disable_audio :
208+ if args .audio_tp != 1 :
209+ raise ValueError (
210+ "audio_tp > 1 is not supported for the audio encoder yet; use --audio_dp for multi-GPU data parallel."
211+ )
212+ if args .audio_gpu_ids is None :
213+ args .audio_gpu_ids = list (range (args .audio_dp * args .audio_tp ))
214+ total_audio_gpus = args .audio_dp * args .audio_tp
215+ if len (args .audio_gpu_ids ) < total_audio_gpus :
216+ raise ValueError (
217+ f"Not enough audio GPUs specified. Need at least { total_audio_gpus } , "
218+ f"but got { len (args .audio_gpu_ids )} ."
219+ )
220+ args .audio_gpu_ids = args .audio_gpu_ids [:total_audio_gpus ]
221+ if args .audio_dp <= 0 :
222+ raise ValueError ("audio_dp must be a positive integer." )
223+ if args .audio_infer_batch_size is None :
224+ args .audio_infer_batch_size = args .audio_dp * 4
225+ if args .audio_infer_batch_size < 1 :
226+ raise ValueError ("audio_infer_batch_size must be >= 1." )
227+ if args .audio_infer_batch_size // args .audio_dp < 1 or args .audio_infer_batch_size % args .audio_dp != 0 :
228+ raise ValueError (
229+ f"audio_infer_batch_size ({ args .audio_infer_batch_size } ) must be "
230+ f"a positive integer multiple of audio_dp ({ args .audio_dp } )."
231+ )
232+
207233 if args .disable_chunked_prefill :
208234 args .chunked_prefill_size = args .max_req_total_len
209235 # 普通模式下
@@ -248,6 +274,8 @@ def normal_or_p_d_start(args):
248274 already_uesd_ports .append (args .pd_decode_rpyc_port )
249275 if args .visual_nccl_ports is not None :
250276 already_uesd_ports .extend (args .visual_nccl_ports [: args .visual_dp ])
277+ if not args .disable_audio and args .audio_nccl_ports is not None :
278+ already_uesd_ports .extend (args .audio_nccl_ports [: args .audio_dp ])
251279
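252280 # Reserve the ports up front so that, when several instances are launched on a single machine,
253281 # a port-setting conflict is caught here instead of only being detected at model start-up.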
@@ -256,7 +284,7 @@ def normal_or_p_d_start(args):
256284
257285 node_world_size = args .tp // args .nnodes
258286 can_use_ports = alloc_can_use_network_port (
259- num = 10 + node_world_size + args .visual_dp * args .visual_tp + args .visual_dp ,
287+ num = 10 + node_world_size + args .visual_dp * args .visual_tp + args .visual_dp + args . audio_dp ,
260288 used_ports = already_uesd_ports ,
261289 )
262290 logger .info (f"alloced ports: { can_use_ports } " )
@@ -274,14 +302,17 @@ def normal_or_p_d_start(args):
274302 ) = can_use_ports [0 :10 ]
275303 can_use_ports = can_use_ports [10 :]
276304
277- visual_nccl_ports = []
278- for _ in range ( args .visual_dp ):
279- if args .visual_nccl_ports is None :
280- visual_nccl_ports . append ( can_use_ports [ 0 ])
281- can_use_ports = can_use_ports [ 1 : ]
305+ if args . visual_nccl_ports is None :
306+ args .visual_nccl_ports = can_use_ports [: args . visual_dp ]
307+ can_use_ports = can_use_ports [ args .visual_dp :]
308+ else :
309+ args . visual_nccl_ports = args . visual_nccl_ports [: args . visual_dp ]
282310
283- if args .visual_nccl_ports is not None :
284- visual_nccl_ports = args .visual_nccl_ports [: args .visual_dp ]
311+ if args .audio_nccl_ports is None :
312+ args .audio_nccl_ports = can_use_ports [: args .audio_dp ]
313+ can_use_ports = can_use_ports [args .audio_dp :]
314+ else :
315+ args .audio_nccl_ports = args .audio_nccl_ports [: args .audio_dp ]
285316
286317 # 将申请好的端口放入args参数中
287318 if args .nccl_port is None :
@@ -296,7 +327,6 @@ def normal_or_p_d_start(args):
296327 args .cache_port = cache_port
297328 args .metric_port = metric_port
298329 args .multi_level_kv_cache_port = multi_level_kv_cache_port
299- args .visual_nccl_ports = visual_nccl_ports
300330 # 申请在 p d 分离模式下,会用的端口
301331 args .pd_node_infer_rpyc_ports = can_use_ports [0 :node_world_size ]
302332 # p d 分离模式下用于标识节点的id
0 commit comments