@@ -16,6 +16,8 @@ class InferenceClientError(Exception):
1616
1717
1818class AsyncStatus (str , Enum ):
19+ """Async status."""
20+
1921 Initialized = 'Initialized'
2022 Queue = 'Queue'
2123 Inference = 'Inference'
@@ -25,6 +27,8 @@ class AsyncStatus(str, Enum):
2527@dataclass_json (undefined = Undefined .EXCLUDE )
2628@dataclass
2729class InferenceResponse :
30+ """Inference response."""
31+
2832 headers : CaseInsensitiveDict [str ]
2933 status_code : int
3034 status_text : str
@@ -66,6 +70,7 @@ def _is_stream_response(self, headers: CaseInsensitiveDict[str]) -> bool:
6670 )
6771
6872 def output (self , is_text : bool = False ) -> Any :
73+ """Get response output as a string or object."""
6974 try :
7075 if is_text :
7176 return self ._original_response .text
@@ -99,11 +104,12 @@ def stream(self, chunk_size: int = 512, as_text: bool = True) -> Generator[Any,
99104
100105
101106class InferenceClient :
107+ """Inference client."""
108+
102109 def __init__ (
103110 self , inference_key : str , endpoint_base_url : str , timeout_seconds : int = 60 * 5
104111 ) -> None :
105- """
106- Initialize the InferenceClient.
112+ """Initialize the InferenceClient.
107113
108114 Args:
109115 inference_key: The authentication key for the API
@@ -139,17 +145,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
139145
140146 @property
141147 def global_headers (self ) -> dict [str , str ]:
142- """
143- Get the current global headers that will be used for all requests.
148+ """Get the current global headers that will be used for all requests.
144149
145150 Returns:
146151 Dictionary of current global headers
147152 """
148153 return self ._global_headers .copy ()
149154
150155 def set_global_header (self , key : str , value : str ) -> None :
151- """
152- Set or update a global header that will be used for all requests.
156+ """Set or update a global header that will be used for all requests.
153157
154158 Args:
155159 key: Header name
@@ -158,17 +162,15 @@ def set_global_header(self, key: str, value: str) -> None:
158162 self ._global_headers [key ] = value
159163
160164 def set_global_headers (self , headers : dict [str , str ]) -> None :
161- """
162- Set multiple global headers at once that will be used for all requests.
165+ """Set multiple global headers at once that will be used for all requests.
163166
164167 Args:
165168 headers: Dictionary of headers to set globally
166169 """
167170 self ._global_headers .update (headers )
168171
169172 def remove_global_header (self , key : str ) -> None :
170- """
171- Remove a global header.
173+ """Remove a global header.
172174
173175 Args:
174176 key: Header name to remove from global headers
@@ -183,8 +185,7 @@ def _build_url(self, path: str) -> str:
183185 def _build_request_headers (
184186 self , request_headers : dict [str , str ] | None = None
185187 ) -> dict [str , str ]:
186- """
187- Build the final headers by merging global headers with request-specific headers.
188+ """Build the final headers by merging global headers with request-specific headers.
188189
189190 Args:
190191 request_headers: Optional headers specific to this request
@@ -198,8 +199,7 @@ def _build_request_headers(
198199 return headers
199200
200201 def _make_request (self , method : str , path : str , ** kwargs ) -> requests .Response :
201- """
202- Make an HTTP request with error handling.
202+ """Make an HTTP request with error handling.
203203
204204 Args:
205205 method: HTTP method to use
@@ -224,7 +224,9 @@ def _make_request(self, method: str, path: str, **kwargs) -> requests.Response:
224224 response .raise_for_status ()
225225 return response
226226 except requests .exceptions .Timeout as e :
227- raise InferenceClientError (f'Request to { path } timed out after { timeout } seconds' ) from e
227+ raise InferenceClientError (
228+ f'Request to { path } timed out after { timeout } seconds'
229+ ) from e
228230 except requests .exceptions .RequestException as e :
229231 raise InferenceClientError (f'Request to { path } failed: { str (e )} ' ) from e
230232
@@ -331,6 +333,7 @@ def get(
331333 headers : dict [str , str ] | None = None ,
332334 timeout_seconds : int | None = None ,
333335 ) -> requests .Response :
336+ """Make GET request."""
334337 return self ._make_request (
335338 'GET' , path , params = params , headers = headers , timeout_seconds = timeout_seconds
336339 )
@@ -344,6 +347,7 @@ def post(
344347 headers : dict [str , str ] | None = None ,
345348 timeout_seconds : int | None = None ,
346349 ) -> requests .Response :
350+ """Make POST request."""
347351 return self ._make_request (
348352 'POST' ,
349353 path ,
@@ -363,6 +367,7 @@ def put(
363367 headers : dict [str , str ] | None = None ,
364368 timeout_seconds : int | None = None ,
365369 ) -> requests .Response :
370+ """Make PUT request."""
366371 return self ._make_request (
367372 'PUT' ,
368373 path ,
@@ -380,6 +385,7 @@ def delete(
380385 headers : dict [str , str ] | None = None ,
381386 timeout_seconds : int | None = None ,
382387 ) -> requests .Response :
388+ """Make DELETE request."""
383389 return self ._make_request (
384390 'DELETE' ,
385391 path ,
@@ -397,6 +403,7 @@ def patch(
397403 headers : dict [str , str ] | None = None ,
398404 timeout_seconds : int | None = None ,
399405 ) -> requests .Response :
406+ """Make PATCH request."""
400407 return self ._make_request (
401408 'PATCH' ,
402409 path ,
@@ -414,6 +421,7 @@ def head(
414421 headers : dict [str , str ] | None = None ,
415422 timeout_seconds : int | None = None ,
416423 ) -> requests .Response :
424+ """Make HEAD request."""
417425 return self ._make_request (
418426 'HEAD' ,
419427 path ,
@@ -429,6 +437,7 @@ def options(
429437 headers : dict [str , str ] | None = None ,
430438 timeout_seconds : int | None = None ,
431439 ) -> requests .Response :
440+ """Make OPTIONS request."""
432441 return self ._make_request (
433442 'OPTIONS' ,
434443 path ,
@@ -438,8 +447,7 @@ def options(
438447 )
439448
440449 def health (self , healthcheck_path : str = '/health' ) -> requests .Response :
441- """
442- Check the health status of the API.
450+ """Check the health status of the API.
443451
444452 Returns:
445453 requests.Response: The response from the health check
@@ -456,22 +464,23 @@ def health(self, healthcheck_path: str = '/health') -> requests.Response:
456464@dataclass_json (undefined = Undefined .EXCLUDE )
457465@dataclass
458466class AsyncInferenceExecution :
467+ """Async inference execution."""
468+
459469 _inference_client : 'InferenceClient'
460470 id : str
461471 _status : AsyncStatus
462472 INFERENCE_ID_HEADER = 'X-Inference-Id'
463473
464474 def status (self ) -> AsyncStatus :
465- """Get the current stored status of the async inference execution. Only the status value type
475+ """Get the currently stored status of the async inference execution (status value only).
466476
467477 Returns:
468478 AsyncStatus: The status object
469479 """
470-
471480 return self ._status
472481
473482 def status_json (self ) -> dict [str , Any ]:
474- """Get the current status of the async inference execution. Return the status json
483+ """Get the current status of the async inference execution as a JSON dictionary.
475484
476485 Returns:
477486 Dict[str, Any]: The status response containing the execution status and other metadata
0 commit comments