Conversation
Code Review
This pull request refactors the audio server to support data and tensor parallelism, moving embedding cache management from the models to background worker threads in the RPC server. It also adds CLI configuration for audio resources and performance testing utilities. Feedback focuses on addressing potential resource leaks in the cleanup process, handling silent failures in background threads, removing redundant loops, improving concurrency when processing multiple data-parallel groups, and adding error handling for process initialization timeouts.
```python
def clean_up(self):
    for model_rpc in self.model_rpcs:
        model_rpc.rpc_server_process.kill()
    for model_rpc in self.model_rpcs:
        model_rpc.rpc_server_process.join()
    return
```
The clean_up method is now empty, but it should terminate the model RPC processes started in wait_to_model_ready. Additionally, the start_model_process function in model_infer/__init__.py no longer returns the process object, making it difficult to track and kill these processes. This could lead to resource leaks or zombie processes when the server is stopped.
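A minimal sketch of one way to restore process tracking (the class names and `start_model_process` signature here are illustrative, not the PR's actual API): have `start_model_process` return the `Process` object, keep it on the RPC handle, and have `clean_up` signal every process before joining any of them.

```python
import multiprocessing as mp
import time


def _rpc_server_loop():
    # Placeholder for the real RPC serve loop.
    while True:
        time.sleep(1)


def start_model_process():
    # Return the process object so the caller can track and stop it later.
    proc = mp.Process(target=_rpc_server_loop, daemon=True)
    proc.start()
    return proc


class ModelRpcClient:
    def __init__(self):
        self.rpc_server_process = start_model_process()


class AudioServer:
    def __init__(self, world_size):
        self.model_rpcs = [ModelRpcClient() for _ in range(world_size)]

    def clean_up(self):
        # Signal all processes first, then join, so shutdown is not serialized
        # behind one slow child.
        for model_rpc in self.model_rpcs:
            model_rpc.rpc_server_process.terminate()
        for model_rpc in self.model_rpcs:
            model_rpc.rpc_server_process.join()
```

Joining after terminating (rather than relying on daemon flags alone) avoids leaving zombie processes behind when the server stops.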
```python
except Exception as e:
    logger.exception(str(e))
    raise e
```
Exceptions in the _infer_worker or _store_worker threads will cause the threads to terminate. Since these are critical background workers, their failure will cause the audio server to stop functioning silently. Consider adding a mechanism to detect thread failure and either restart them or shut down the process gracefully.
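One possible mitigation, sketched below (the `supervised` helper, `fatal_error` event, and the bounded-restart policy are illustrative, not part of this PR): wrap each worker target so crashes are logged and retried a bounded number of times, and raise a process-level flag when the worker is permanently dead so the main loop can shut down instead of running on silently.

```python
import logging
import threading

logger = logging.getLogger(__name__)
fatal_error = threading.Event()


def supervised(target, *args, restarts=1):
    """Run a worker target in a thread, restarting it on failure.

    After `restarts` failed restarts, set `fatal_error` so the main
    process can detect the dead worker and exit gracefully.
    """

    def runner():
        for _ in range(restarts + 1):
            try:
                target(*args)
                return  # worker exited cleanly
            except Exception:
                logger.exception("worker %s crashed", getattr(target, "__name__", target))
        fatal_error.set()

    t = threading.Thread(target=runner, daemon=True)
    t.start()
    return t
```

The main loop would then periodically check `fatal_error.is_set()` (or wait on it) and trigger shutdown, rather than continuing with `_infer_worker` or `_store_worker` silently gone.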
```python
all_embeds = []
for i in range(len(audio_items)):
    cur_embed = per_audio_features[i]
    cpu_embed_cache_client.copy_to_cache(
        embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
    )
    assert (
        item.token_num == cur_embed.shape[0]
    ), f"audio token num not match {item.token_num} vs {cur_embed.shape[0]} "
    ids_to_set.append(uid)
    if ids_to_set:
        self.cache_client.root.set_items_embed(ids=ids_to_set)
        torch.cuda.current_stream().synchronize()
    all_embeds.append(cur_embed)
return all_embeds, audio_items
```
The loop to build all_embeds is redundant because per_audio_features is already a list containing the embeddings for each item in audio_items in the same order.
Suggested change:

```diff
-all_embeds = []
-for i in range(len(audio_items)):
-    cur_embed = per_audio_features[i]
-    cpu_embed_cache_client.copy_to_cache(
-        embed_tensor=cur_embed, start_index_in_cache=item.start_index_in_embed_cache
-    )
-    assert (
-        item.token_num == cur_embed.shape[0]
-    ), f"audio token num not match {item.token_num} vs {cur_embed.shape[0]} "
-    ids_to_set.append(uid)
-    if ids_to_set:
-        self.cache_client.root.set_items_embed(ids=ids_to_set)
-        torch.cuda.current_stream().synchronize()
-    all_embeds.append(cur_embed)
-return all_embeds, audio_items
+return per_audio_features, audio_items
```
```python
for dp_index in range(self.audio_dp):
    _audios = dp_to_handle_audios[dp_index]
    if _audios:
        await asyncio.to_thread(_audios[-1][1].wait)
```
The current implementation waits for the completion of each DP group sequentially. If multiple DP groups are processing audio items from the same request, they should be waited on concurrently to improve performance.
Suggested change:

```diff
-for dp_index in range(self.audio_dp):
-    _audios = dp_to_handle_audios[dp_index]
-    if _audios:
-        await asyncio.to_thread(_audios[-1][1].wait)
+wait_tasks = []
+for dp_index in range(self.audio_dp):
+    _audios = dp_to_handle_audios[dp_index]
+    if _audios:
+        wait_tasks.append(asyncio.to_thread(_audios[-1][1].wait))
+if wait_tasks:
+    await asyncio.gather(*wait_tasks)
```
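A toy demonstration of why gathering the waits matters (the three `threading.Event` objects here stand in for the per-DP-group completion events; timings are illustrative): when the waits run through `asyncio.gather`, the total wall time is roughly the slowest group rather than the sum of all groups.

```python
import asyncio
import threading
import time


async def main():
    # Simulate three DP groups whose completion events are set by workers
    # after ~0.2s each.
    events = [threading.Event() for _ in range(3)]
    for ev in events:
        threading.Timer(0.2, ev.set).start()

    start = time.perf_counter()
    # Concurrent waits: total elapsed is ~0.2s. Awaiting each
    # asyncio.to_thread(ev.wait) sequentially would take ~0.6s here
    # if the events were set one after another.
    await asyncio.gather(*(asyncio.to_thread(ev.wait) for ev in events))
    return time.perf_counter() - start


elapsed = asyncio.run(main())
```

In this toy case all events fire together, so sequential waits would also finish quickly; the benefit appears when groups finish at different times, since `gather` bounds the wait by the slowest group only.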
```python
await asyncio.to_thread(success_event.wait, timeout=40)
assert proc.is_alive()
```
The return value of success_event.wait is not checked. If the process fails to start within the 40-second timeout, the code will proceed to attempt a connection, which will likely fail or hang. It's better to explicitly check for the timeout.
Suggested change:

```diff
-await asyncio.to_thread(success_event.wait, timeout=40)
-assert proc.is_alive()
+if not await asyncio.to_thread(success_event.wait, timeout=40):
+    proc.terminate()
+    raise RuntimeError("Audio model inference process failed to start within timeout")
```
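The key point is that `Event.wait(timeout=...)` returns `False` on timeout rather than raising. A synchronous sketch of the same startup pattern (the `start_with_timeout` helper and its worker protocol are illustrative, not this PR's code):

```python
import multiprocessing as mp
import time


def start_with_timeout(target, timeout=40.0):
    """Start a worker process and wait for it to signal readiness.

    wait() returns False on timeout, so a stuck worker is surfaced
    immediately instead of hanging on a later connection attempt.
    """
    success_event = mp.Event()
    proc = mp.Process(target=target, args=(success_event,), daemon=True)
    proc.start()
    if not success_event.wait(timeout=timeout):
        proc.terminate()
        proc.join()
        raise RuntimeError("worker process failed to start within timeout")
    return proc
```

In the async server, the same check wraps the `asyncio.to_thread(success_event.wait, timeout=40)` call, as the suggestion above shows.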