-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathapi.py
More file actions
282 lines (222 loc) · 9.67 KB
/
api.py
File metadata and controls
282 lines (222 loc) · 9.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import asyncio
import importlib
import logging
import os.path
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional
from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
from fastapi import FastAPI, Request
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from prometheus_fastapi_instrumentator import Instrumentator
from domain import Tags, TagsStreamable
from utils import get_settings
import api.globals as cms_globals
from api.auth.db import make_sure_db_and_tables
from api.auth.users import Props
from api.dependencies import ModelServiceDep
from api.utils import add_exception_handlers, add_rate_limiter
from management.tracker_client import TrackerClient
logging.getLogger("asyncio").setLevel(logging.ERROR)
logger = logging.getLogger("cms")
def get_model_server(msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
app = _get_app(msd_overwritten)
config = get_settings()
logger.debug("Configuration loaded: %s", config)
add_rate_limiter(app, config)
logger.debug("Rate limiter added")
app = _load_health_check_router(app)
logger.debug("Health check router loaded")
if config.AUTH_USER_ENABLED == "true":
app = _load_auth_router(app)
logger.debug("Auth router loaded")
app = _load_model_card(app)
logger.debug("Model card router loaded")
app = _load_invocation_router(app)
logger.debug("Invocation router loaded")
if config.ENABLE_TRAINING_APIS == "true":
app = _load_supervised_training_router(app)
logger.debug("Supervised training router loaded")
if config.DISABLE_UNSUPERVISED_TRAINING != "true":
app = _load_unsupervised_training_router(app)
logger.debug("Unsupervised training router loaded")
if config.DISABLE_METACAT_TRAINING != "true":
app = _load_metacat_training_router(app)
logger.debug("Metacat training router loaded")
if config.ENABLE_EVALUATION_APIS == "true":
app = _load_evaluation_router(app)
logger.debug("Evaluation router loaded")
if config.ENABLE_PREVIEWS_APIS == "true":
app = _load_preview_router(app)
logger.debug("Preview router loaded")
return app
def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
app = _get_app(msd_overwritten, streamable=True)
config = get_settings()
add_rate_limiter(app, config, streamable=True)
app = _load_health_check_router(app)
logger.debug("Health check router loaded")
if config.AUTH_USER_ENABLED == "true":
app = _load_auth_router(app)
logger.debug("Auth router loaded")
app = _load_model_card(app)
logger.debug("Model card router loaded")
app = _load_stream_router(app)
logger.debug("Stream router loaded")
return app
def _get_app(
msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False
) -> FastAPI:
tags_metadata = [
{"name": tag.name, "description": tag.value}
for tag in (Tags if not streamable else TagsStreamable)
]
config = get_settings()
app = FastAPI(
title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
docs_url=None,
redoc_url=None,
debug=(config.DEBUG == "true"),
openapi_tags=tags_metadata,
)
add_exception_handlers(app)
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]
).instrument(app)
if msd_overwritten is not None:
cms_globals.model_service_dep = msd_overwritten
cms_globals.props = Props(config.AUTH_USER_ENABLED == "true")
app.mount(
"/static",
StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")),
name="static",
)
@app.on_event("startup")
async def on_startup() -> None:
loop = asyncio.get_running_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=50))
RunVar("_default_thread_limiter").set(CapacityLimiter(50)) # type: ignore
logger.debug("Default thread pool executor set to 50")
instrumentator.expose(app, include_in_schema=False, should_gzip=False)
logger.debug("Prometheus instrumentator metrics exposed")
if config.AUTH_USER_ENABLED == "true":
await make_sure_db_and_tables()
logger.debug("Auth database and tables are ready")
@app.get("/docs", include_in_schema=False)
async def swagger_doc(req: Request) -> HTMLResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
openapi_url = root_path + app.openapi_url
oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
if oauth2_redirect_url:
oauth2_redirect_url = root_path + oauth2_redirect_url
return get_swagger_ui_html(
openapi_url=openapi_url,
title="CogStack ModelServe",
oauth2_redirect_url=oauth2_redirect_url,
init_oauth=app.swagger_ui_init_oauth,
swagger_favicon_url="/static/images/favicon.ico",
swagger_ui_parameters=app.swagger_ui_parameters,
)
@app.get("/redoc", include_in_schema=False)
async def redoc_doc(req: Request) -> HTMLResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
openapi_url = root_path + app.openapi_url
return get_redoc_html(
openapi_url=openapi_url,
title="CogStack ModelServe",
redoc_favicon_url="/static/images/favicon.ico",
)
@app.get("/", include_in_schema=False)
async def root_redirect() -> RedirectResponse:
return RedirectResponse(url="/docs")
@app.on_event("shutdown")
async def on_shutdown() -> None:
TrackerClient.end_with_interruption()
logger.debug("Tracker client terminated")
def custom_openapi() -> Dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=f"{cms_globals.model_service_dep().model_name} APIs",
version=cms_globals.model_service_dep().api_version,
description=(
"by CogStack ModelServe, a model serving and governance system for CogStack NLP"
" solutions."
),
routes=app.routes,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://avatars.githubusercontent.com/u/28688163?s=200&v=4"
}
for path in openapi_schema["paths"].values():
for method_data in path.values():
if "requestBody" in method_data:
for content_type, content in method_data["requestBody"]["content"].items():
if content_type == "multipart/form-data":
schema_name = content["schema"]["$ref"].lstrip("#/components/schemas/")
schema_data = openapi_schema["components"]["schemas"].pop(schema_name)
schema_data["title"] = "UploadFile"
content["schema"] = schema_data
elif content_type == "application/x-www-form-urlencoded":
schema_name = content["schema"]["$ref"].lstrip("#/components/schemas/")
schema_data = openapi_schema["components"]["schemas"].pop(schema_name)
schema_data["title"] = "FormData"
content["schema"] = schema_data
app.openapi_schema = openapi_schema
logger.debug("Custom OpenAPI schema generated")
return app.openapi_schema
return app
def _load_auth_router(app: FastAPI) -> FastAPI:
from api.routers import authentication
importlib.reload(authentication)
app.include_router(authentication.router)
return app
def _load_model_card(app: FastAPI) -> FastAPI:
from api.routers import model_card
importlib.reload(model_card)
app.include_router(model_card.router)
return app
def _load_invocation_router(app: FastAPI) -> FastAPI:
from api.routers import invocation
importlib.reload(invocation)
app.include_router(invocation.router)
return app
def _load_supervised_training_router(app: FastAPI) -> FastAPI:
from api.routers import supervised_training
importlib.reload(supervised_training)
app.include_router(supervised_training.router)
return app
def _load_evaluation_router(app: FastAPI) -> FastAPI:
from api.routers import evaluation
importlib.reload(evaluation)
app.include_router(evaluation.router)
return app
def _load_preview_router(app: FastAPI) -> FastAPI:
from api.routers import preview
importlib.reload(preview)
app.include_router(preview.router)
return app
def _load_unsupervised_training_router(app: FastAPI) -> FastAPI:
from api.routers import unsupervised_training
importlib.reload(unsupervised_training)
app.include_router(unsupervised_training.router)
return app
def _load_metacat_training_router(app: FastAPI) -> FastAPI:
from api.routers import metacat_training
importlib.reload(metacat_training)
app.include_router(metacat_training.router)
return app
def _load_health_check_router(app: FastAPI) -> FastAPI:
from api.routers import health_check
importlib.reload(health_check)
app.include_router(health_check.router)
return app
def _load_stream_router(app: FastAPI) -> FastAPI:
from api.routers import stream
importlib.reload(stream)
app.include_router(stream.router, prefix="/stream")
return app