Skip to content

Commit 9f43400

Browse files
author
niushengxiao
committed
refine
1 parent ecf34ae commit 9f43400

10 files changed

Lines changed: 33 additions & 76 deletions

File tree

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
3131
def __init__(self, network_config):
3232
super().__init__(network_config)
3333
self.args = get_env_start_args()
34-
self.cache_client = None
3534
if self.args.enable_remote_vit:
3635
self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
3736
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@@ -52,12 +51,14 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
5251
img_token_lens = []
5352
img_start_locs_in_cache = []
5453
unique_uids = []
54+
all_uids = []
5555
device = layer_weight.wte_weight_.weight.device
5656
dtype = layer_weight.wte_weight_.weight.dtype
5757
hidden_size = layer_weight.wte_weight_.weight.shape[1]
5858

5959
for _, p in enumerate(infer_state.multimodal_params):
6060
for img in p["images"] + p["audios"]:
61+
all_uids.append(img["uuid"])
6162
# skip the same image
6263
if img["token_id"] in img_start_token_ids:
6364
continue
@@ -77,17 +78,12 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
7778
)
7879

7980
if self.args.enable_remote_vit:
80-
release_ids = []
81-
for _, p in enumerate(infer_state.multimodal_params):
82-
for img in p["images"] + p["audios"]:
83-
release_ids.append(img["uuid"])
84-
8581
for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache):
8682
embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir)
8783
self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache)
8884

89-
if release_ids:
90-
self.cache_client.root.release(release_ids)
85+
if all_uids:
86+
self.cache_client.root.release(all_uids)
9187

9288
assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
9389
f"Dimension mismatch: text weight dimension is {hidden_size}, "

lightllm/server/api_http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def set_args(self, args: StartArgs):
9292
self.httpserver_manager = HttpServerManagerForPDMaster(
9393
args=args,
9494
)
95-
elif args.run_mode == "visual":
95+
elif args.run_mode in ["visual", "visual_only"]:
9696
self.metric_client = MetricClient(args.metric_port)
9797
else:
9898
init_tokenizer(args) # for openai api
@@ -138,7 +138,7 @@ def get_model_name():
138138
@app.get("/health", summary="Check server health")
139139
@app.head("/health", summary="Check server health")
140140
async def healthcheck(request: Request):
141-
if g_objs.args.run_mode in ["pd_master", "visual"]:
141+
if g_objs.args.run_mode in ["pd_master", "visual", "visual_only"]:
142142
return JSONResponse({"message": "Ok"}, status_code=200)
143143

144144
if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":

lightllm/server/api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
pd_master_start(args)
1212
elif args.run_mode == "config_server":
1313
config_server_start(args)
14-
elif args.run_mode == "visual":
14+
elif args.run_mode in ["visual", "visual_only"]:
1515
visual_start(args)
1616
else:
1717
normal_or_p_d_start(args)

lightllm/server/api_start.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import subprocess
66
import signal
77
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
8-
from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode
8+
from lightllm.utils.start_utils import process_manager, kill_recursive
99
from .metrics.manager import start_metric_manager
1010
from .embed_cache.manager import start_cache_manager
1111
from lightllm.utils.log_utils import init_logger
@@ -194,7 +194,6 @@ def normal_or_p_d_start(args, only_prepare=False):
194194
assert args.mtp_draft_model_dir is None
195195
assert args.mtp_step == 0
196196

197-
args.enable_multimodal = is_multimodal_mode(args)
198197
_prepare_remote_vit_embed_dir(args)
199198
# 检查GPU数量是否足够
200199
if args.visual_gpu_ids is None:
@@ -355,27 +354,27 @@ def normal_or_p_d_start(args, only_prepare=False):
355354
start_args=[(args,)],
356355
)
357356

358-
if not args.disable_audio:
359-
from .audioserver.manager import start_audio_process
357+
if not args.disable_vision and not args.enable_remote_vit:
358+
from .visualserver.manager import start_visual_process
360359

361360
process_manager.start_submodule_processes(
362361
start_funcs=[
363-
start_audio_process,
362+
start_visual_process,
364363
],
365364
start_args=[
366-
(args,),
365+
(args, visual_model_tp_ports),
367366
],
368367
)
369368

370-
if not args.disable_vision and not args.enable_remote_vit:
371-
from .visualserver.manager import start_visual_process
369+
if not args.disable_audio:
370+
from .audioserver.manager import start_audio_process
372371

373372
process_manager.start_submodule_processes(
374373
start_funcs=[
375-
start_visual_process,
374+
start_audio_process,
376375
],
377376
start_args=[
378-
(args, visual_model_tp_ports),
377+
(args,),
379378
],
380379
)
381380

lightllm/server/embed_cache/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -
5454

5555

5656
def get_cache_manager(args):
57-
if args.enable_remote_vit or args.run_mode == "visual":
57+
if args.enable_remote_vit or args.run_mode in ["visual", "visual_only"]:
5858
return MemoryCacheWithRedis(args)
5959
else:
6060
return InMemoryCache(args)

lightllm/server/httpserver/manager.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,15 @@ def __init__(
8282
if self.enable_multimodal:
8383
self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
8484
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
85-
if not self.args.disable_vision:
86-
from lightllm.server.visualserver.vit_connect import VITConnectionManager
8785

88-
self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client)
86+
if not self.args.disable_vision:
87+
from lightllm.server.visualserver.vit_connect import VITConnectionManager
8988

90-
if not self.args.disable_audio:
91-
self.send_to_audio = context.socket(zmq.PUSH)
92-
self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}")
89+
self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client)
90+
91+
if not self.args.disable_audio:
92+
self.send_to_audio = context.socket(zmq.PUSH)
93+
self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}")
9394

9495
if args.enable_cpu_cache and not self.args.enable_multimodal:
9596
self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH)
@@ -151,7 +152,6 @@ async def _alloc_resource(self, items, uuids, token_nums, datas):
151152
if self.args.enable_remote_vit:
152153
# 避免远端lru被逐出
153154
self.cache_client.root.get_items_embed(uid_list, False)
154-
return
155155

156156
ready_flags = obtain(self.cache_client.root.get_items_data(uid_list))
157157
update_data_ids = []
@@ -592,25 +592,13 @@ async def transfer_to_next_module(
592592

593593
if self.pd_mode.is_P_or_NORMAL():
594594
group_req_index = group_req_objs.to_group_req_index()
595-
has_images = len(group_req_index.multimodal_params.images) > 0
596-
has_audios = len(group_req_index.multimodal_params.audios) > 0
597-
598-
if has_images and not self.args.disable_vision:
599-
free_mode = "all"
600-
if self.args.enable_remote_vit and has_audios and not self.args.disable_audio:
601-
free_mode = "images"
602-
603-
await self.vit_manager.send_to_vit(
604-
group_req_index, protocol=pickle.HIGHEST_PROTOCOL, free_mode=free_mode
605-
)
606-
595+
if not self.args.disable_vision:
596+
await self.vit_manager.send_to_vit(group_req_index, protocol=pickle.HIGHEST_PROTOCOL)
607597
if not self.args.enable_remote_vit:
608598
return
609599

610-
if has_audios and not self.args.disable_audio:
600+
if not self.args.disable_audio:
611601
self.send_to_audio.send_pyobj(group_req_index, protocol=pickle.HIGHEST_PROTOCOL)
612-
if self.args.enable_remote_vit:
613-
group_req_index.multimodal_params.free()
614602
return
615603

616604
if self.args.enable_cpu_cache:

lightllm/server/multimodal_params.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __init__(self, **kwargs):
2626
self.token_num = None
2727
# the audio length
2828
self.audio_length = None
29-
self.afs_embed = False
3029

3130
self._preload_data = None
3231
self.extra_params = {}
@@ -55,11 +54,10 @@ async def preload(self, request: Request):
5554

5655
def read(self):
5756
assert self._preload_data is not None
58-
return self._preload_data
59-
60-
def free(self):
57+
ans = self._preload_data
6158
self._preload_data = None
6259
self._data = None
60+
return ans
6361

6462
def to_dict(self):
6563
ret = {}
@@ -167,23 +165,10 @@ def __init__(
167165
self.audios = [AudioItem(**a) for a in audios]
168166
return
169167

170-
def free(self):
171-
for image in self.images:
172-
image.free()
173-
for audio in self.audios:
174-
audio.free()
175-
176168
def free_images(self):
177169
for image in self.images:
178170
image.free()
179171

180-
def free_audios(self):
181-
for audio in self.audios:
182-
audio.free()
183-
184-
def get_all_uuids(self):
185-
return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios]
186-
187172
async def verify_and_preload(self, request: Request):
188173
for image in self.images:
189174
await image.preload(request)

lightllm/server/visualserver/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ async def loop_for_netio_req(self):
239239
self.waiting_reqs.append(recv_req)
240240
else:
241241
assert False, f"Error Req Inf {recv_req}"
242-
self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256)
242+
self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256))
243243
except zmq.ZMQError:
244244
# 当队列已经开始清空的时候,将一次接受数量下调
245245
self.visual_recv_max_count = 64

lightllm/server/visualserver/vit_connect.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _get_vit_instance(self):
159159
self.current_vit_index = index
160160
return list(self.remote_vit_instances.values())[index]
161161

162-
async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, free_mode: str = "all"):
162+
async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL):
163163
"""
164164
发送数据到VIT实例,支持本地和远程模式
165165
"""
@@ -174,10 +174,7 @@ async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOC
174174
if self.remote_vit:
175175
await self._wait_visual_embed_ready(req)
176176

177-
if free_mode == "all":
178-
req.multimodal_params.free()
179-
elif free_mode == "images":
180-
req.multimodal_params.free_images()
177+
req.multimodal_params.free_images()
181178

182179
async def vit_handle_loop(self):
183180
"""
@@ -223,7 +220,7 @@ async def _wait_visual_embed_ready(
223220
# 本地模式不需要等待
224221
if not self.remote_vit:
225222
return
226-
uuids = req.multimodal_params.get_all_uuids()
223+
uuids = [image.uuid for image in req.multimodal_params.images]
227224

228225
async def wait_for_embeds():
229226
while not all(self.cache_client.root.get_items_embed(uuids, True)):

lightllm/utils/start_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,4 @@ def kill_recursive(proc):
111111
logger.warning(f"Process {proc.pid} does not exist.")
112112

113113

114-
def is_multimodal_mode(args):
115-
from transformers import PretrainedConfig
116-
117-
model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir)
118-
is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg
119-
return is_multimodal
120-
121-
122114
process_manager = SubmoduleManager()

0 commit comments

Comments
 (0)