|
7 | 7 | from enum import Enum, EnumMeta |
8 | 8 | from os import path |
9 | 9 | from pathlib import Path |
10 | | -from threading import Lock |
11 | 10 | from types import TracebackType |
12 | 11 | from typing import ( |
13 | 12 | TYPE_CHECKING, |
|
18 | 17 | Literal, |
19 | 18 | Optional, |
20 | 19 | Type, |
21 | | - TypeVar, |
22 | 20 | Union, |
23 | 21 | ) |
24 | 22 |
|
25 | 23 | import grpc |
26 | 24 | from google.protobuf.descriptor import EnumDescriptor |
27 | 25 |
|
28 | 26 | from ni_measurementlink_service import _datatypeinfo |
| 27 | +from ni_measurementlink_service._channelpool import ( # re-export |
| 28 | + GrpcChannelPool as GrpcChannelPool, |
| 29 | +) |
29 | 30 | from ni_measurementlink_service._internal import grpc_servicer |
30 | 31 | from ni_measurementlink_service._internal.discovery_client import DiscoveryClient |
31 | 32 | from ni_measurementlink_service._internal.parameter import ( |
32 | 33 | metadata as parameter_metadata, |
33 | 34 | ) |
34 | 35 | from ni_measurementlink_service._internal.service_manager import GrpcService |
35 | | -from ni_measurementlink_service._loggers import ClientLogger |
36 | 36 | from ni_measurementlink_service.measurement.info import ( |
37 | 37 | DataType, |
38 | 38 | MeasurementInfo, |
|
49 | 49 | else: |
50 | 50 | from typing_extensions import TypeGuard |
51 | 51 |
|
| 52 | + if sys.version_info >= (3, 11): |
| 53 | + from typing import Self |
| 54 | + else: |
| 55 | + from typing_extensions import Self |
| 56 | + |
52 | 57 | SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper] |
53 | 58 |
|
54 | 59 |
|
@@ -83,67 +88,6 @@ def abort(self, code: grpc.StatusCode, details: str) -> None: |
83 | 88 | grpc_servicer.measurement_service_context.get().abort(code, details) |
84 | 89 |
|
85 | 90 |
|
86 | | -# Eventually, these can be replaced with typing.Self (Python >= 3.11). |
87 | | -_TGrpcChannelPool = TypeVar("_TGrpcChannelPool", bound="GrpcChannelPool") |
88 | | -_TMeasurementService = TypeVar("_TMeasurementService", bound="MeasurementService") |
89 | | - |
90 | | - |
91 | | -class GrpcChannelPool(object): |
92 | | - """Class that manages gRPC channel lifetimes.""" |
93 | | - |
94 | | - def __init__(self) -> None: |
95 | | - """Initialize the GrpcChannelPool object.""" |
96 | | - self._lock: Lock = Lock() |
97 | | - self._channel_cache: Dict[str, grpc.Channel] = {} |
98 | | - |
99 | | - def __enter__(self: _TGrpcChannelPool) -> _TGrpcChannelPool: |
100 | | - """Enter the runtime context of the GrpcChannelPool.""" |
101 | | - return self |
102 | | - |
103 | | - def __exit__( |
104 | | - self, |
105 | | - exc_type: Optional[Type[BaseException]], |
106 | | - exc_val: Optional[BaseException], |
107 | | - traceback: Optional[TracebackType], |
108 | | - ) -> Literal[False]: |
109 | | - """Exit the runtime context of the GrpcChannelPool.""" |
110 | | - self.close() |
111 | | - return False |
112 | | - |
113 | | - def get_channel(self, target: str) -> grpc.Channel: |
114 | | - """Return a gRPC channel. |
115 | | -
|
116 | | - Args: |
117 | | - target (str): The server address |
118 | | -
|
119 | | - """ |
120 | | - new_channel = None |
121 | | - with self._lock: |
122 | | - if target not in self._channel_cache: |
123 | | - self._lock.release() |
124 | | - new_channel = grpc.insecure_channel(target) |
125 | | - if ClientLogger.is_enabled(): |
126 | | - new_channel = grpc.intercept_channel(new_channel, ClientLogger()) |
127 | | - self._lock.acquire() |
128 | | - if target not in self._channel_cache: |
129 | | - self._channel_cache[target] = new_channel |
130 | | - new_channel = None |
131 | | - channel = self._channel_cache[target] |
132 | | - |
133 | | - # Close new_channel if it was not stored in _channel_cache. |
134 | | - if new_channel is not None: |
135 | | - new_channel.close() |
136 | | - |
137 | | - return channel |
138 | | - |
139 | | - def close(self) -> None: |
140 | | - """Close channels opened by get_channel().""" |
141 | | - with self._lock: |
142 | | - for channel in self._channel_cache.values(): |
143 | | - channel.close() |
144 | | - self._channel_cache.clear() |
145 | | - |
146 | | - |
147 | 91 | class MeasurementService: |
148 | 92 | """Class that supports registering and hosting a python function as a gRPC service. |
149 | 93 |
|
@@ -451,7 +395,7 @@ def close_service(self) -> None: |
451 | 395 | self.grpc_service.stop() |
452 | 396 | self.channel_pool.close() |
453 | 397 |
|
454 | | - def __enter__(self: _TMeasurementService) -> _TMeasurementService: |
| 398 | + def __enter__(self: Self) -> Self: |
455 | 399 | """Enter the runtime context related to the measurement service.""" |
456 | 400 | return self |
457 | 401 |
|
|
0 commit comments