Skip to content

Commit f8321f1

Browse files
authored
service: Split GrpcChannelPool into a separate submodule (#376)
* service: Split GrpcChannelPool into a separate submodule * service: Use typing_extensions.Self * service: Fix lint warnings
1 parent 9fd91b6 commit f8321f1

3 files changed

Lines changed: 96 additions & 71 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from threading import Lock
5+
from types import TracebackType
6+
from typing import (
7+
Dict,
8+
Literal,
9+
Optional,
10+
Type,
11+
TYPE_CHECKING,
12+
)
13+
14+
import grpc
15+
16+
from ni_measurementlink_service._loggers import ClientLogger
17+
18+
if TYPE_CHECKING:
19+
if sys.version_info >= (3, 11):
20+
from typing import Self
21+
else:
22+
from typing_extensions import Self
23+
24+
25+
class GrpcChannelPool(object):
26+
"""Class that manages gRPC channel lifetimes."""
27+
28+
def __init__(self) -> None:
29+
"""Initialize the GrpcChannelPool object."""
30+
self._lock: Lock = Lock()
31+
self._channel_cache: Dict[str, grpc.Channel] = {}
32+
33+
def __enter__(self: Self) -> Self:
34+
"""Enter the runtime context of the GrpcChannelPool."""
35+
return self
36+
37+
def __exit__(
38+
self,
39+
exc_type: Optional[Type[BaseException]],
40+
exc_val: Optional[BaseException],
41+
traceback: Optional[TracebackType],
42+
) -> Literal[False]:
43+
"""Exit the runtime context of the GrpcChannelPool."""
44+
self.close()
45+
return False
46+
47+
def get_channel(self, target: str) -> grpc.Channel:
48+
"""Return a gRPC channel.
49+
50+
Args:
51+
target (str): The server address
52+
53+
"""
54+
new_channel = None
55+
with self._lock:
56+
if target not in self._channel_cache:
57+
self._lock.release()
58+
new_channel = grpc.insecure_channel(target)
59+
if ClientLogger.is_enabled():
60+
new_channel = grpc.intercept_channel(new_channel, ClientLogger())
61+
self._lock.acquire()
62+
if target not in self._channel_cache:
63+
self._channel_cache[target] = new_channel
64+
new_channel = None
65+
channel = self._channel_cache[target]
66+
67+
# Close new_channel if it was not stored in _channel_cache.
68+
if new_channel is not None:
69+
new_channel.close()
70+
71+
return channel
72+
73+
def close(self) -> None:
74+
"""Close channels opened by get_channel()."""
75+
with self._lock:
76+
for channel in self._channel_cache.values():
77+
channel.close()
78+
self._channel_cache.clear()

ni_measurementlink_service/measurement/service.py

Lines changed: 9 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from enum import Enum, EnumMeta
88
from os import path
99
from pathlib import Path
10-
from threading import Lock
1110
from types import TracebackType
1211
from typing import (
1312
TYPE_CHECKING,
@@ -18,21 +17,22 @@
1817
Literal,
1918
Optional,
2019
Type,
21-
TypeVar,
2220
Union,
2321
)
2422

2523
import grpc
2624
from google.protobuf.descriptor import EnumDescriptor
2725

2826
from ni_measurementlink_service import _datatypeinfo
27+
from ni_measurementlink_service._channelpool import ( # re-export
28+
GrpcChannelPool as GrpcChannelPool,
29+
)
2930
from ni_measurementlink_service._internal import grpc_servicer
3031
from ni_measurementlink_service._internal.discovery_client import DiscoveryClient
3132
from ni_measurementlink_service._internal.parameter import (
3233
metadata as parameter_metadata,
3334
)
3435
from ni_measurementlink_service._internal.service_manager import GrpcService
35-
from ni_measurementlink_service._loggers import ClientLogger
3636
from ni_measurementlink_service.measurement.info import (
3737
DataType,
3838
MeasurementInfo,
@@ -49,6 +49,11 @@
4949
else:
5050
from typing_extensions import TypeGuard
5151

52+
if sys.version_info >= (3, 11):
53+
from typing import Self
54+
else:
55+
from typing_extensions import Self
56+
5257
SupportedEnumType = Union[Type[Enum], _EnumTypeWrapper]
5358

5459

@@ -83,67 +88,6 @@ def abort(self, code: grpc.StatusCode, details: str) -> None:
8388
grpc_servicer.measurement_service_context.get().abort(code, details)
8489

8590

86-
# Eventually, these can be replaced with typing.Self (Python >= 3.11).
87-
_TGrpcChannelPool = TypeVar("_TGrpcChannelPool", bound="GrpcChannelPool")
88-
_TMeasurementService = TypeVar("_TMeasurementService", bound="MeasurementService")
89-
90-
91-
class GrpcChannelPool(object):
92-
"""Class that manages gRPC channel lifetimes."""
93-
94-
def __init__(self) -> None:
95-
"""Initialize the GrpcChannelPool object."""
96-
self._lock: Lock = Lock()
97-
self._channel_cache: Dict[str, grpc.Channel] = {}
98-
99-
def __enter__(self: _TGrpcChannelPool) -> _TGrpcChannelPool:
100-
"""Enter the runtime context of the GrpcChannelPool."""
101-
return self
102-
103-
def __exit__(
104-
self,
105-
exc_type: Optional[Type[BaseException]],
106-
exc_val: Optional[BaseException],
107-
traceback: Optional[TracebackType],
108-
) -> Literal[False]:
109-
"""Exit the runtime context of the GrpcChannelPool."""
110-
self.close()
111-
return False
112-
113-
def get_channel(self, target: str) -> grpc.Channel:
114-
"""Return a gRPC channel.
115-
116-
Args:
117-
target (str): The server address
118-
119-
"""
120-
new_channel = None
121-
with self._lock:
122-
if target not in self._channel_cache:
123-
self._lock.release()
124-
new_channel = grpc.insecure_channel(target)
125-
if ClientLogger.is_enabled():
126-
new_channel = grpc.intercept_channel(new_channel, ClientLogger())
127-
self._lock.acquire()
128-
if target not in self._channel_cache:
129-
self._channel_cache[target] = new_channel
130-
new_channel = None
131-
channel = self._channel_cache[target]
132-
133-
# Close new_channel if it was not stored in _channel_cache.
134-
if new_channel is not None:
135-
new_channel.close()
136-
137-
return channel
138-
139-
def close(self) -> None:
140-
"""Close channels opened by get_channel()."""
141-
with self._lock:
142-
for channel in self._channel_cache.values():
143-
channel.close()
144-
self._channel_cache.clear()
145-
146-
14791
class MeasurementService:
14892
"""Class that supports registering and hosting a python function as a gRPC service.
14993
@@ -451,7 +395,7 @@ def close_service(self) -> None:
451395
self.grpc_service.stop()
452396
self.channel_pool.close()
453397

454-
def __enter__(self: _TMeasurementService) -> _TMeasurementService:
398+
def __enter__(self: Self) -> Self:
455399
"""Enter the runtime context related to the measurement service."""
456400
return self
457401

ni_measurementlink_service/session_management.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from __future__ import annotations
33

44
import abc
5+
import sys
56
import warnings
67
from functools import cached_property
78
from types import TracebackType
89
from typing import (
10+
TYPE_CHECKING,
911
Any,
1012
Iterable,
1113
List,
@@ -14,7 +16,6 @@
1416
Optional,
1517
Sequence,
1618
Type,
17-
TypeVar,
1819
)
1920

2021
import grpc
@@ -29,6 +30,12 @@
2930
session_management_service_pb2_grpc,
3031
)
3132

33+
if TYPE_CHECKING:
34+
if sys.version_info >= (3, 11):
35+
from typing import Self
36+
else:
37+
from typing_extensions import Self
38+
3239
GRPC_SERVICE_INTERFACE_NAME = "ni.measurementlink.sessionmanagement.v1.SessionManagementService"
3340
GRPC_SERVICE_CLASS = "ni.measurementlink.sessionmanagement.v1.SessionManagementService"
3441

@@ -175,10 +182,6 @@ def _convert_session_info_to_grpc(
175182
)
176183

177184

178-
# Eventually, this can be replaced with typing.Self (Python >= 3.11).
179-
_TReservation = TypeVar("_TReservation", bound="BaseReservation")
180-
181-
182185
class BaseReservation(abc.ABC):
183186
"""Manages session reservation."""
184187

@@ -191,7 +194,7 @@ def __init__(
191194
self._session_manager = session_manager
192195
self._session_info = session_info
193196

194-
def __enter__(self: _TReservation) -> _TReservation:
197+
def __enter__(self: Self) -> Self:
195198
"""Context management protocol. Returns self."""
196199
return self
197200

0 commit comments

Comments
 (0)