This repository was archived by the owner on Jan 23, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathbase.py
More file actions
149 lines (113 loc) · 4.73 KB
/
base.py
File metadata and controls
149 lines (113 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Base classes for drivers and driver clients
"""
from __future__ import annotations
from contextlib import ExitStack, contextmanager
from dataclasses import field
from anyio.from_thread import BlockingPortal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from .core import AsyncDriverClient
from jumpstarter.common.importlib import _format_missing_driver_message
from jumpstarter.streams.blocking import BlockingStream
@dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True))
class DriverClient(AsyncDriverClient):
    """Base class for driver clients

    Client methods can be implemented as regular functions,
    and call the `call` or `streamingcall` helpers internally
    to invoke exported methods on the driver.

    Additional client functionalities such as raw stream
    connections or sharing client-side resources can be added
    by inheriting mixin classes under `jumpstarter.drivers.mixins`
    """

    # Child driver clients (keys presumably child names — TODO confirm).
    children: dict[str, DriverClient] = field(default_factory=dict)
    # Portal used to run the async client API from synchronous code.
    portal: BlockingPortal
    # Holds stream sessions opened by open_stream() until close() is called.
    stack: ExitStack

    description: str | None = None
    """Driver description from GetReport(), used for CLI help text"""
    methods_description: dict[str, str] = field(default_factory=dict)
    """Map of method names to their help descriptions from GetReport()"""

    def call(self, method, *args):
        """
        Invoke driver call

        Runs ``call_async`` through the blocking portal and waits for it.

        :param str method: method name of driver call
        :param list[Any] args: arguments for driver call
        :return: driver call result
        :rtype: Any
        """
        return self.portal.call(self.call_async, method, *args)

    def streamingcall(self, method, *args):
        """
        Invoke streaming driver call

        Each item is fetched through the blocking portal. If the caller stops
        iterating early (``break``, exception, or closing this generator), the
        underlying async iterator is closed on the portal right away instead of
        being left for garbage collection.

        :param str method: method name of streaming driver call
        :param list[Any] args: arguments for streaming driver call
        :return: streaming driver call result
        :rtype: Generator[Any, None, None]
        """
        generator = self.portal.call(self.streamingcall_async, method, *args)
        exhausted = False
        try:
            while True:
                # Keep the try minimal: only __anext__ signals end-of-stream.
                try:
                    item = self.portal.call(generator.__anext__)
                except StopAsyncIteration:
                    exhausted = True
                    return
                yield item
        finally:
            if not exhausted:
                # Release the server-side stream deterministically; guarded
                # in case streamingcall_async returns a plain async iterator.
                aclose = getattr(generator, "aclose", None)
                if aclose is not None:
                    self.portal.call(aclose)

    @contextmanager
    def stream(self, method="connect"):
        """
        Open a blocking stream session with a context manager.

        :param str method: method name of streaming driver call
        :return: blocking stream session object context manager.
        """
        with self.portal.wrap_async_context_manager(self.stream_async(method)) as stream:
            yield BlockingStream(stream=stream, portal=self.portal)

    @contextmanager
    def log_stream(self):
        """
        Keep the async ``log_stream_async()`` context active for the duration
        of the ``with`` block.

        Yields nothing; the context is entered and exited via the portal.
        """
        with self.portal.wrap_async_context_manager(self.log_stream_async()):
            yield

    def open_stream(self) -> BlockingStream:
        """
        Open a blocking stream session without a context manager.

        The session is registered on ``self.stack`` and is released by
        ``close()``.

        :return: blocking stream session object.
        :rtype: BlockingStream
        """
        return self.stack.enter_context(self.stream())

    def close(self):
        """
        Close the open stream session without a context manager.

        Exits every context registered on ``self.stack``.
        """
        self.stack.close()

    def __del__(self):
        # If __init__ failed (e.g. validation error) before `stack` was set,
        # there is nothing to release; avoid an AttributeError in the finalizer.
        if getattr(self, "stack", None) is not None:
            self.close()
@dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True))
class StubDriverClient(DriverClient):
    """Placeholder client used when a driver's client class is not installed.

    Instances stand in for a client class that failed to import. Every
    synchronous entry point (``call``, ``streamingcall``, ``stream``,
    ``log_stream``) raises an ``ImportError`` explaining how to install the
    missing driver, instead of failing in some less obvious way.
    """

    def _get_missing_class_path(self) -> str:
        """Return the import path of the missing client, read from labels."""
        labels = self.labels
        return labels["jumpstarter.dev/client"]

    def _raise_missing_error(self):
        """Raise an ImportError whose message carries install instructions."""
        raise ImportError(_format_missing_driver_message(self._get_missing_class_path()))

    def call(self, method, *args):
        """Always raises ImportError: the driver client is not installed."""
        return self._raise_missing_error()

    def streamingcall(self, method, *args):
        """Always raises ImportError (on first iteration): driver not installed."""
        self._raise_missing_error()
        return
        # Never reached; its presence makes this a generator function, so the
        # error surfaces lazily on iteration just like DriverClient.streamingcall.
        yield

    @contextmanager
    def stream(self, method="connect"):
        """Always raises ImportError on entry: the driver client is not installed."""
        self._raise_missing_error()
        yield None

    @contextmanager
    def log_stream(self):
        """Always raises ImportError on entry: the driver client is not installed."""
        self._raise_missing_error()
        yield None