Fix: Improve HTTP API structure and async handler usage (#569) #2063
New file: health router (`@@ -0,0 +1,55 @@`)

```python
from typing import Any, Optional

from fastapi import APIRouter, Depends
from starlette.responses import JSONResponse

from inference.core.env import DOCKER_SOCKET_PATH
from inference.core.managers.metrics import get_container_stats
from inference.core.utils.container import is_docker_socket_mounted


def create_health_router(model_init_state: Optional[Any] = None) -> APIRouter:
    router = APIRouter()

    @router.get("/device/stats", summary="Device/container statistics")
    def device_stats():
        not_configured_error_message = {
            "error": "Device statistics endpoint is not enabled.",
            "hint": (
                "Mount the Docker socket and point to its location when running the docker "
                "container to collect device stats "
                "(i.e. `docker run ... -v /var/run/docker.sock:/var/run/docker.sock "
                "-e DOCKER_SOCKET_PATH=/var/run/docker.sock ...`)."
            ),
        }
        if not DOCKER_SOCKET_PATH:
            return JSONResponse(
                status_code=404,
                content=not_configured_error_message,
            )
        if not is_docker_socket_mounted(docker_socket_path=DOCKER_SOCKET_PATH):
            return JSONResponse(
                status_code=500,
                content=not_configured_error_message,
            )

        container_stats = get_container_stats(docker_socket_path=DOCKER_SOCKET_PATH)
        return JSONResponse(status_code=200, content=container_stats)

    @router.get("/readiness", status_code=200)
    def readiness(state: Any = Depends(lambda: model_init_state)):
        """Readiness endpoint for Kubernetes readiness probe."""
        if state is None:
            return {"status": "ready"}
        with state.lock:
            if state.is_ready:
                return {"status": "ready"}
            return JSONResponse(
                content={"status": "not ready"}, status_code=503
            )

    @router.get("/healthz", status_code=200)
    def healthz():
        """Health endpoint for Kubernetes liveness probe."""
        return {"status": "healthy"}

    return router
```
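As a sanity check, here is a minimal wiring sketch (hypothetical, not part of the diff) that mounts this router and exercises the probe endpoints with FastAPI's `TestClient`:

```python
# Hypothetical usage sketch, not part of the PR diff.
# create_health_router comes from the new module shown above.
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
app.include_router(create_health_router(model_init_state=None))

client = TestClient(app)
assert client.get("/healthz").json() == {"status": "healthy"}
# With model_init_state=None, the readiness probe reports ready immediately.
assert client.get("/readiness").json() == {"status": "ready"}
```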
New file: inference router (`@@ -0,0 +1,208 @@`)

```python
from typing import List, Optional, Union

from fastapi import APIRouter, BackgroundTasks

from inference.core import logger
from inference.core.entities.requests.inference import (
    ClassificationInferenceRequest,
    DepthEstimationRequest,
    InferenceRequest,
    InstanceSegmentationInferenceRequest,
    KeypointsDetectionInferenceRequest,
    ObjectDetectionInferenceRequest,
)
from inference.core.entities.responses.inference import (
    ClassificationInferenceResponse,
    DepthEstimationResponse,
    InferenceResponse,
    InstanceSegmentationInferenceResponse,
    KeypointsDetectionInferenceResponse,
    MultiLabelClassificationInferenceResponse,
    ObjectDetectionInferenceResponse,
    StubResponse,
)
from inference.core.env import DEPTH_ESTIMATION_ENABLED
from inference.core.interfaces.http.error_handlers import with_route_exceptions
from inference.core.interfaces.http.orjson_utils import orjson_response
from inference.core.managers.base import ModelManager
from inference.core.utils.model_alias import resolve_roboflow_model_alias
```

**Contributor:** This was moved.
```python
from inference.usage_tracking.collector import usage_collector


def create_inference_router(
    model_manager: ModelManager,
) -> APIRouter:
    router = APIRouter()
```

**Contributor:** With this modular approach we could also improve the docs; this would group the doc endpoints nicely.
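The contributor's accompanying snippet was not captured in this view; one plausible reading (an assumption, not the original suggestion) is to tag each router so its endpoints are grouped in the OpenAPI docs:

```python
# Hypothetical illustration of the grouping idea; the tag names are
# assumptions, not taken from the contributor's (elided) snippet.
from fastapi import FastAPI

app = FastAPI()
app.include_router(create_health_router(), tags=["health"])
app.include_router(create_info_router(), tags=["server-info"])
# The inference router would need a real ModelManager instance:
# app.include_router(create_inference_router(model_manager), tags=["inference"])
```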
```python
    def process_inference_request(
        inference_request: InferenceRequest,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> InferenceResponse:
```

**Contributor:** From what I see, the docstrings were removed. Maybe it's a good opportunity to add docstrings for the public functions in the modules you created. WDYT?
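For illustration, a docstring along these lines could address the comment (the wording is assumed, not the author's):

```python
# Illustrative only; the docstring wording is an assumption, and the types
# come from the imports shown earlier in this diff.
def process_inference_request(
    inference_request: InferenceRequest,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
    **kwargs,
) -> InferenceResponse:
    """Resolve the model alias, register the model, and run sync inference.

    Args:
        inference_request: Request carrying the model_id, api_key, and inputs.
        countinference: Optional flag forwarded to model_manager.add_model.
        service_secret: Optional secret forwarded to model_manager.add_model.
        **kwargs: Extra options forwarded to infer_from_request_sync.

    Returns:
        The inference response serialized with orjson_response.
    """
    ...
```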
```python
        de_aliased_model_id = resolve_roboflow_model_alias(
            model_id=inference_request.model_id
        )
        model_manager.add_model(
            de_aliased_model_id,
            inference_request.api_key,
            countinference=countinference,
            service_secret=service_secret,
        )
        resp = model_manager.infer_from_request_sync(
            de_aliased_model_id,
            inference_request,
            **kwargs,
        )
        return orjson_response(resp)

    @router.post(
        "/infer/object_detection",
        response_model=Union[
            ObjectDetectionInferenceResponse,
            List[ObjectDetectionInferenceResponse],
            StubResponse,
        ],
        summary="Object detection infer",
        description="Run inference with the specified object detection model",
        response_model_exclude_none=True,
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_object_detection(
```

**Contributor:** As mentioned previously, let's bring back docstrings on the functions.
```python
        inference_request: ObjectDetectionInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        logger.debug("Reached /infer/object_detection")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/instance_segmentation",
        response_model=Union[InstanceSegmentationInferenceResponse, StubResponse],
        summary="Instance segmentation infer",
        description="Run inference with the specified instance segmentation model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_instance_segmentation(
        inference_request: InstanceSegmentationInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        logger.debug("Reached /infer/instance_segmentation")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/semantic_segmentation",
        response_model=Union[InstanceSegmentationInferenceResponse, StubResponse],
```

**Contributor:** Suggested change to the `response_model` above.
| summary="Semantic segmentation infer", | ||||||
| description="Run inference with the specified semantic segmentation model", | ||||||
| ) | ||||||
| @with_route_exceptions | ||||||
| @usage_collector("request") | ||||||
| def infer_semantic_segmentation( | ||||||
| inference_request, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing type hint, please double check the code to see if there are other places like that. |
||||||
```python
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        logger.debug("Reached /infer/semantic_segmentation")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/classification",
        response_model=Union[
            ClassificationInferenceResponse,
            MultiLabelClassificationInferenceResponse,
            StubResponse,
        ],
        summary="Classification infer",
        description="Run inference with the specified classification model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_classification(
        inference_request: ClassificationInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        logger.debug("Reached /infer/classification")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/keypoints_detection",
        response_model=Union[KeypointsDetectionInferenceResponse, StubResponse],
        summary="Keypoints detection infer",
        description="Run inference with the specified keypoints detection model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_keypoints(
        inference_request: KeypointsDetectionInferenceRequest,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        logger.debug("Reached /infer/keypoints_detection")
        return process_inference_request(
            inference_request,
            countinference=countinference,
            service_secret=service_secret,
        )

    if DEPTH_ESTIMATION_ENABLED:

        @router.post(
            "/infer/depth-estimation",
            response_model=DepthEstimationResponse,
            summary="Depth Estimation",
            description="Run the depth estimation model to generate a depth map.",
        )
        @with_route_exceptions
        def depth_estimation(
            inference_request: DepthEstimationRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            logger.debug("Reached /infer/depth-estimation")
            depth_model_id = inference_request.model_id
            model_manager.add_model(
                depth_model_id,
                inference_request.api_key,
                countinference=countinference,
                service_secret=service_secret,
            )
            response = model_manager.infer_from_request_sync(
                depth_model_id, inference_request
            )
            return response

    return router
```
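On the missing type hint flagged for `infer_semantic_segmentation` above, a hypothetical fix might look like the following; no dedicated semantic segmentation request type appears in this diff's imports, so the base `InferenceRequest` here is an assumption:

```python
# Hypothetical annotation; a dedicated semantic segmentation request type is
# not imported in this diff, so the base InferenceRequest is used instead.
def infer_semantic_segmentation(
    inference_request: InferenceRequest,
    background_tasks: BackgroundTasks,
    countinference: Optional[bool] = None,
    service_secret: Optional[str] = None,
):
    ...
```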
New file: info router (`@@ -0,0 +1,68 @@`)

```python
from typing import Optional

from fastapi import APIRouter, HTTPException, Query

rom inference.core.version import __version__
```

**Contributor (suggested change):**

```diff
-rom inference.core.version import __version__
+from inference.core.version import __version__
```
```python
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
from inference.core.entities.responses.server_state import ServerVersionInfo


def create_info_router() -> APIRouter:
    router = APIRouter()

    @router.get(
        "/info",
        response_model=ServerVersionInfo,
        summary="Info",
        description="Get the server name and version number",
    )
    def root():
        """Endpoint to get the server name and version number.

        Returns:
            ServerVersionInfo: The server version information.
        """
        return ServerVersionInfo(
            name="Roboflow Inference Server",
            version=__version__,
            uuid=GLOBAL_INFERENCE_SERVER_ID,
        )

    @router.get(
        "/logs",
        summary="Get Recent Logs",
        description="Get recent application logs for debugging",
    )
    def get_logs(
        limit: Optional[int] = Query(
            100, description="Maximum number of log entries to return"
        ),
        level: Optional[str] = Query(
            None,
            description="Filter by log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        ),
        since: Optional[str] = Query(
            None, description="Return logs since this ISO timestamp"
        ),
    ):
        """Only available when ENABLE_IN_MEMORY_LOGS is set to 'true'."""
```

**Contributor:** The docstrings and the FastAPI param/endpoint descriptions have two different roles: one provides input for the API documentation, the other directly comments the code (helpful when using an IDE). So let's leave both. Unfortunately we have inconsistencies throughout this API; for example, here we are missing any response description.
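A sketch of the reviewer's point (the wording and status codes are assumptions, not from the PR): document the response shapes through FastAPI's `responses` parameter while keeping the docstring for IDE users.

```python
# Sketch only; descriptions and status codes are assumed, not from the PR.
from fastapi import APIRouter

router = APIRouter()

@router.get(
    "/logs",
    summary="Get Recent Logs",
    description="Get recent application logs for debugging",
    responses={
        200: {"description": "Recent log entries plus a total count."},
        404: {"description": "In-memory logging is not enabled."},
    },
)
def get_logs_sketch():
    """Docstring kept alongside the OpenAPI metadata, per the review comment."""
    return {"logs": [], "total_count": 0}
```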
```python
        from inference.core.logging.memory_handler import (
            get_recent_logs,
            is_memory_logging_enabled,
        )

        if not is_memory_logging_enabled():
            raise HTTPException(
                status_code=404, detail="Logs endpoint not available"
            )

        try:
            logs = get_recent_logs(limit=limit or 100, level=level, since=since)
            return {"logs": logs, "total_count": len(logs)}
        except (ImportError, ModuleNotFoundError):
            raise HTTPException(
                status_code=500, detail="Logging system not properly initialized"
            )

    return router
```
**Contributor:** Always add a new line at the end of the file.