-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathgateway.py
More file actions
242 lines (184 loc) · 5.8 KB
/
gateway.py
File metadata and controls
242 lines (184 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
from __future__ import annotations
import dataclasses
import logging
from io import BytesIO
from operator import attrgetter
from typing import SupportsBytes
from typing import TypeVar
from databento_dbn import Compression
from databento_dbn import Encoding
from databento_dbn import Schema
from databento_dbn import SType
from databento.common.enums import SlowReaderBehavior
from databento.common.publishers import Dataset
from databento.common.system import USER_AGENT
# Module-level logger; `GatewayDecoder.decode` uses it to record gateway
# messages that fail to parse.
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class GatewayControl(SupportsBytes):
    """
    Base class for gateway control messages.

    A message is a single line of `key=value` tokens separated by `|` and
    terminated by a newline; each dataclass field corresponds to one key.
    """

    GC = TypeVar("GC", bound="GatewayControl")

    @classmethod
    def parse(cls: type[GC], line: str | bytes) -> GC:
        """
        Parse a `GatewayControl` message from a string.

        Parameters
        ----------
        line : str | bytes
            The data to parse into a GatewayControl message. Bytes are
            decoded as UTF-8. The line must end with a newline.

        Returns
        -------
        GC
            An instance of the class `parse` was called on.

        Raises
        ------
        ValueError
            If the line fails to parse (this includes invalid UTF-8, since
            `UnicodeDecodeError` is a `ValueError`).

        """
        if isinstance(line, bytes):
            line = line.decode("utf-8")

        if not line.endswith("\n"):
            raise ValueError(f"{line!r} does not end with a newline")

        # Partition on the first `=` only so values may themselves contain `=`.
        split_tokens = [t.partition("=") for t in line.strip().split("|")]
        data_dict = {k: v for k, _, v in split_tokens}

        try:
            return cls(**data_dict)
        except TypeError:
            # The generated dataclass __init__ raises TypeError for unknown
            # keys or missing required fields; report it as a parse failure.
            raise ValueError(
                f"{line!r} is not a parseable {cls.__name__}",
            ) from None

    def __str__(self) -> str:
        # Serialize fields in declaration order, omitting any set to None.
        fields = tuple(map(attrgetter("name"), dataclasses.fields(self)))
        values = tuple(getattr(self, f) for f in fields)
        tokens = "|".join(f"{k}={v}" for k, v in zip(fields, values) if v is not None)
        return f"{tokens}\n"

    def __bytes__(self) -> bytes:
        return str(self).encode("utf-8")
@dataclasses.dataclass
class Greeting(GatewayControl):
    """
    First message the gateway sends once a client has connected.

    `lsg_version` carries the version string reported by the gateway.
    """

    lsg_version: str  # required; no default
@dataclasses.dataclass
class ChallengeRequest(GatewayControl):
    """
    Challenge the gateway sends to a newly connected client.

    `cram` holds the challenge text. Presumably the client answers it through
    the `auth` field of an `AuthenticationRequest` (confirm against the
    gateway protocol documentation).
    """

    cram: str  # required; no default
@dataclasses.dataclass
class AuthenticationResponse(GatewayControl):
    """
    Reply the gateway sends after receiving a valid authentication request.

    `success` is kept as the raw string value from the wire. `error` and
    `session_id` are optional and stay `None` when the gateway omits them.
    """

    success: str
    error: str | None = dataclasses.field(default=None)
    session_id: str | None = dataclasses.field(default=None)
@dataclasses.dataclass
class AuthenticationRequest(GatewayControl):
    """
    Message the client sends to authenticate with the gateway.

    It is sent after the gateway's challenge (a `ChallengeRequest`) has been
    received, and is required to authenticate a user. Optional fields left as
    `None` are omitted when the request is serialized.
    """

    auth: str
    dataset: Dataset | str
    encoding: Encoding = dataclasses.field(default=Encoding.DBN)
    details: str | None = dataclasses.field(default=None)
    ts_out: str = dataclasses.field(default="0")
    compression: Compression | str = dataclasses.field(default=Compression.NONE)
    heartbeat_interval_s: int | None = dataclasses.field(default=None)
    slow_reader_behavior: SlowReaderBehavior | str | None = dataclasses.field(
        default=None,
    )
    client: str = dataclasses.field(default=USER_AGENT)

    def __post_init__(self) -> None:
        # Temporary workaround for LSG support: "skip", given either as the
        # enum member or as a raw string, is sent to the gateway as "drop".
        if self.slow_reader_behavior not in (SlowReaderBehavior.SKIP, "skip"):
            return
        self.slow_reader_behavior = "drop"
@dataclasses.dataclass
class SubscriptionRequest(GatewayControl):
    """
    Request the client sends to the gateway to open a subscription.

    Fields left as `None` (`start`, `id`) are omitted when the request is
    serialized.
    """

    schema: Schema | str
    stype_in: SType
    symbols: str
    start: int | None = dataclasses.field(default=None)
    snapshot: int = dataclasses.field(default=0)
    id: int | None = dataclasses.field(default=None)
    is_last: int = dataclasses.field(default=1)
@dataclasses.dataclass
class SessionStart(GatewayControl):
    """
    Message the client sends to the gateway to start the session.

    It has a single field, `start_session`, which defaults to "0".
    """

    start_session: str = dataclasses.field(default="0")
def parse_gateway_message(line: str) -> GatewayControl:
    """
    Parse a line received from the gateway into the matching message type.

    Each direct subclass of `GatewayControl` is tried in definition order;
    the first one whose `parse` accepts the line is returned.

    Parameters
    ----------
    line : str
        A single newline-terminated gateway message.

    Returns
    -------
    GatewayControl

    Raises
    ------
    ValueError
        If `line` is not a parseable GatewayControl message.

    """
    candidates = GatewayControl.__subclasses__()
    for candidate in candidates:
        try:
            parsed = candidate.parse(line)
        except ValueError:
            # Not this message type; fall through to the next candidate.
            continue
        return parsed
    raise ValueError(f"'{line.strip()}' is not a parseable gateway message")
class GatewayDecoder:
    """
    Decoder for gateway control messages.

    Raw bytes are accumulated with `write`. `decode` then extracts every
    complete, line-feed-terminated message and keeps any trailing partial
    message buffered for the next call.
    """

    def __init__(self) -> None:
        self.__buffer = BytesIO()

    @property
    def buffer(self) -> BytesIO:
        """
        The internal buffer for decoding messages.

        Returns
        -------
        BytesIO

        """
        return self.__buffer

    def write(self, data: bytes) -> None:
        """
        Write data to the decoder's buffer. This will make the data available
        for decoding.

        Parameters
        ----------
        data : bytes
            The data to write.

        """
        # `decode` replaces the buffer with a fresh BytesIO positioned at 0,
        # so always append at the end.
        self.__buffer.seek(0, 2)  # seek to end
        self.__buffer.write(data)

    def decode(self) -> list[GatewayControl]:
        """
        Decode messages from the decoder's buffer. This will consume decoded
        data from the buffer.

        Only a line feed byte terminates a message; any bytes after the last
        line feed are kept in the buffer until more data arrives.

        Returns
        -------
        list[GatewayControl]

        Raises
        ------
        ValueError
            If a complete line cannot be decoded or parsed. The buffer is
            left unchanged in that case.

        """
        data = self.__buffer.getvalue()
        cursor = 0
        messages: list[GatewayControl] = []

        # Split on b"\n" only. bytes.splitlines() also breaks on a bare b"\r",
        # producing a fragment that never ends with b"\n"; that used to stop
        # decoding and wedge the buffer on that fragment forever.
        while (newline := data.find(b"\n", cursor)) != -1:
            end = newline + 1
            line = data[cursor:end]
            try:
                message = parse_gateway_message(line.decode("utf-8"))
            except ValueError:
                logger.exception("could not parse gateway message: %s", line)
                raise
            cursor = end
            messages.append(message)

        # Keep only the unconsumed tail (an incomplete message, if any).
        self.__buffer = BytesIO(data[cursor:])
        return messages