# Adapted from vllm/entrypoints/api_server.py
# of the vllm-project/vllm GitHub repository.
#
# Copyright 2023 ModelTC Team
# Copyright 2023 vLLM Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import collections
import time
import uvloop
import requests
import base64
import os
from io import BytesIO
import pickle
import setproctitle
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import ujson as json
from http import HTTPStatus
import uuid
from PIL import Image
import multiprocessing as mp
from typing import AsyncGenerator, Union
from typing import Callable
from lightllm.server import TokenLoad
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse, JSONResponse
from lightllm.server.core.objs.sampling_params import SamplingParams
from lightllm.server.core.objs import StartArgs
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
from .api_lightllm import lightllm_get_score
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.error_utils import ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
from dataclasses import dataclass
from .api_openai import chat_completions_impl, completions_impl
from .api_models import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ModelCard,
ModelListResponse,
)
from .build_prompt import build_prompt, init_tokenizer
logger = init_logger(__name__)
@dataclass
class G_Objs:
    app: FastAPI = None
    metric_client: MetricClient = None
    args: StartArgs = None
    g_generate_func: Callable = None
    g_generate_stream_func: Callable = None
    httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
    shared_token_load: TokenLoad = None
    # OpenAI-compatible "created" timestamp for /v1/models.
    # Should be stable for the lifetime of this server process.
    model_created: int = None

    def set_args(self, args: StartArgs):
        self.args = args
        from .api_lightllm import lightllm_generate, lightllm_generate_stream
        from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl

        if args.use_tgi_api:
            self.g_generate_func = tgi_generate_impl
            self.g_generate_stream_func = tgi_generate_stream_impl
        else:
            self.g_generate_func = lightllm_generate
            self.g_generate_stream_func = lightllm_generate_stream

        setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::api_server")
        if args.run_mode == "pd_master":
            self.metric_client = MetricClient(args.metric_port)
            self.httpserver_manager = HttpServerManagerForPDMaster(
                args=args,
            )
        else:
            init_tokenizer(args)  # for openai api
            SamplingParams.load_generation_cfg(args.model_dir)
            CompletionRequest.load_generation_cfg(args.model_dir)
            ChatCompletionRequest.load_generation_cfg(args.model_dir)
            self.metric_client = MetricClient(args.metric_port)
            self.httpserver_manager = HttpServerManager(args=args)
            # Compatibility with multi-node pure-TP setups, where args.dp // args.nnodes can be 1 // 2 == 0.
            dp_size_in_node = max(1, args.dp // args.nnodes)
            self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)

        if self.model_created is None:
            self.model_created = int(time.time())
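

# Module-level application state: g_objs is populated via set_args() in the startup hook
# below, so all route handlers share a single HttpServerManager, MetricClient and TokenLoad
# instance per server process.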
g_objs = G_Objs()
app = FastAPI()
g_objs.app = app
def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
    g_objs.metric_client.counter_inc("lightllm_request_failure")
    return JSONResponse({"message": message}, status_code=status_code.value)

@app.get("/liveness")
@app.post("/liveness")
def liveness():
return {"status": "ok"}
@app.get("/readiness")
@app.post("/readiness")
def readiness():
return {"status": "ok"}
@app.get("/get_model_name")
@app.post("/get_model_name")
def get_model_name():
return {"model_name": g_objs.args.model_name}
@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
async def healthcheck(request: Request):
    if g_objs.args.run_mode == "pd_master":
        return JSONResponse({"message": "Ok"}, status_code=200)

    if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
        return JSONResponse({"message": "Error"}, status_code=503)

    from lightllm.utils.health_check import health_check, health_obj

    health_task = asyncio.create_task(health_check(g_objs.args, g_objs.httpserver_manager, None))
    if not health_obj.is_health():
        await health_task
    return JSONResponse(
        {"message": "Ok" if health_obj.is_health() else "Error"}, status_code=200 if health_obj.is_health() else 503
    )

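# /token_load reports per-dp-rank load estimates; when args.dp == 1 the per-rank lists are
# collapsed to plain scalars before being returned.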
@app.get("/token_load", summary="Get the current server's load of tokens")
async def token_load(request: Request):
    ans_dict = {
        # Currently used tokens: the estimated present load.
        "current_load": [
            float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp)
        ],
        # Naive load estimate obtained by simply adding the input and output lengths of current
        # requests; no longer used, and its value equals dynamic_max_load.
        "logical_max_load": [
            float(g_objs.shared_token_load.get_logical_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
        ],
        # Dynamically estimated maximum load, which accounts for requests that exit midway.
        "dynamic_max_load": [
            float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
        ],
    }
    if g_objs.args.dp == 1:
        ans_dict = {k: v[0] for k, v in ans_dict.items()}

    return JSONResponse(ans_dict, status_code=200)

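# A minimal sketch of calling /generate, assuming the default lightllm request schema handled
# by lightllm_generate (field names and port here are illustrative, not authoritative):
#
#   curl http://127.0.0.1:8080/generate -X POST -H 'Content-Type: application/json' \
#        -d '{"inputs": "What is AI?", "parameters": {"max_new_tokens": 64}}'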
@app.post("/generate")
async def generate(request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    try:
        return await g_objs.g_generate_func(request, g_objs.httpserver_manager)
    except ServerBusyError as e:
        logger.error("%s", str(e), exc_info=True)
        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
    except Exception as e:
        logger.error("An error occurred: %s", str(e), exc_info=True)
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))

@app.post("/generate_stream")
async def generate_stream(request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    try:
        return await g_objs.g_generate_stream_func(request, g_objs.httpserver_manager)
    except ServerBusyError as e:
        logger.error("%s", str(e), exc_info=True)
        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
    except Exception as e:
        logger.error("An error occurred: %s", str(e), exc_info=True)
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))

@app.post("/get_score")
async def get_score(request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    try:
        return await lightllm_get_score(request, g_objs.httpserver_manager)
    except Exception as e:
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))

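# POST / is a compatibility entry point: it pops the "stream" field from the JSON body and
# dispatches to generate_stream() or generate() accordingly.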
@app.post("/")
async def compat_generate(request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    request_dict = await request.json()
    stream = request_dict.pop("stream", False)
    if stream:
        return await generate_stream(request)
    else:
        return await generate(request)

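# A hedged sketch of an OpenAI-style chat request against this route; the payload follows the
# standard Chat Completions shape, and the model id should match what /v1/models reports
# (see get_models below):
#
#   curl http://127.0.0.1:8080/v1/chat/completions -X POST -H 'Content-Type: application/json' \
#        -d '{"model": "default_model_name", "messages": [{"role": "user", "content": "Hi"}]}'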
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    resp = await chat_completions_impl(request, raw_request)
    return resp

@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(request: CompletionRequest, raw_request: Request) -> Response:
    if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
        return create_error_response(
            HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not receive requests from the http interface"
        )
    resp = await completions_impl(request, raw_request)
    return resp

@app.get("/v1/models", response_model=ModelListResponse)
@app.post("/v1/models", response_model=ModelListResponse)
async def get_models(raw_request: Request):
    model_name = g_objs.args.model_name
    max_model_len = g_objs.args.max_req_total_len
    if model_name == "default_model_name" and g_objs.args.model_dir:
        model_name = os.path.basename(g_objs.args.model_dir.rstrip("/"))
    return ModelListResponse(
        data=[
            ModelCard(
                id=model_name,
                created=g_objs.model_created,
                max_model_len=max_model_len,
                owned_by=g_objs.args.model_owner,
            )
        ]
    )

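# /tokens counts prompt tokens without generating. A minimal request body (illustrative values):
#   {"text": "hello world", "parameters": {}, "multimodal_params": {}}
# and the response is {"ntokens": <int>}.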
@app.get("/tokens")
@app.post("/tokens")
async def tokens(request: Request):
    try:
        request_dict = await request.json()
        prompt = request_dict.pop("text")
        sample_params_dict = request_dict.pop("parameters", {})
        sampling_params = SamplingParams()
        sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sample_params_dict)
        sampling_params.verify()
        multimodal_params_dict = request_dict.get("multimodal_params", {})
        multimodal_params = MultimodalParams(**multimodal_params_dict)
        await multimodal_params.verify_and_preload(request)
        return JSONResponse(
            {
                "ntokens": g_objs.httpserver_manager.tokens(
                    prompt, multimodal_params, sampling_params, sample_params_dict
                )
            },
            status_code=200,
        )
    except Exception as e:
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")

@app.get("/metrics")
async def metrics() -> Response:
    data = await g_objs.metric_client.generate_latest()
    # Starlette responses have no ``mimetype`` attribute; use ``media_type`` to set the content type.
    response = Response(data, media_type="text/plain")
    return response

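# /pd_register is used in pd_master mode: a prefill/decode node first sends a JSON registration
# message as text, then streams pickled objects that are forwarded to the handle queue until the
# websocket disconnects.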
@app.websocket("/pd_register")
async def register_and_keep_alive(websocket: WebSocket):
    await websocket.accept()
    websocket._receive_bytes_max_size = get_lightllm_websocket_max_message_size()
    client_ip, client_port = websocket.client
    logger.info(f"Client connected from IP: {client_ip}, Port: {client_port}")
    regist_json = json.loads(await websocket.receive_text())
    logger.info(f"received regist_json {regist_json}")
    await g_objs.httpserver_manager.register_pd(regist_json, websocket)
    try:
        while True:
            # Wait for the next message from the registered client.
            data = await websocket.receive_bytes()
            obj = pickle.loads(data)
            await g_objs.httpserver_manager.put_to_handle_queue(obj)
    except (WebSocketDisconnect, Exception, RuntimeError) as e:
        logger.error(f"client {regist_json} has error {str(e)}")
        logger.exception(str(e))
    finally:
        logger.error(f"client {regist_json} removed")
        await g_objs.httpserver_manager.remove_pd(regist_json)
    return

@app.websocket("/kv_move_status")
async def kv_move_status(websocket: WebSocket):
    await websocket.accept()
    client_ip, client_port = websocket.client
    logger.info(f"kv_move_status Client connected from IP: {client_ip}, Port: {client_port}")
    try:
        while True:
            # Wait for the next status message from the client.
            data = await websocket.receive_bytes()
            upkv_status = pickle.loads(data)
            logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}")
            await g_objs.httpserver_manager.update_req_status(upkv_status)
    except (WebSocketDisconnect, Exception, RuntimeError) as e:
        logger.error(f"kv_move_status client {(client_ip, client_port)} has error {str(e)}")
        logger.exception(str(e))
    return

@app.on_event("shutdown")
async def shutdown():
logger.info("Received signal to shutdown. Performing graceful shutdown...")
await asyncio.sleep(3)
# 杀掉所有子进程
import psutil
import signal
parent = psutil.Process(os.getpid())
children = parent.children(recursive=True)
for child in children:
os.kill(child.pid, signal.SIGKILL)
logger.info("Graceful shutdown completed.")
return
@app.on_event("startup")
async def startup_event():
logger.info("server start up")
loop = asyncio.get_event_loop()
g_objs.set_args(get_env_start_args())
loop.create_task(g_objs.httpserver_manager.handle_loop())
logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
return