Skip to content

Commit cac6f58

Browse files
authored
feat: Zero-downtime support for Database migration (#831)
# Description The old 0.3 version is not able to read the 1.0 entries from database because of the inconsistencies between 1.0 and 0.3 data types. This applies to both DatabaseTaskStore and DatabasePushNotificationConfigStore. This PR fixes this issue by allowing users to write 0.3 compatible entires during migration period. ## Changes - adds new conversion methods to `compat/0_3/conversions.py` - update `DatabaseTaskStore` and `DatabasePushNotificationConfigStore` to accept new conversion methods - utilize new conversion methods of `DatabaseTaskStore` and `DatabasePushNotificationConfigStore` ## Tested Created a database using `0.3` spec containing populated tables `task` and `push_notification_configs`. Ran `uv run a2a-db` using `1.0` spec against the database and added new entries using the new Zero-downtime feature, `DatabaseTaskStore.core_to_model_conversion = core_to_compat_task_model` and `DatabasePushNotificationConfigStore.core_to_model_conversion = core_to_compat_push_notification_config_model`. Succesfully read new entries using `0.3` spec. Fixes #811 🦕
1 parent dedda6c commit cac6f58

6 files changed

Lines changed: 407 additions & 54 deletions

File tree

src/a2a/compat/v0_3/conversions.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import base64
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
4+
5+
6+
if TYPE_CHECKING:
7+
from cryptography.fernet import Fernet
48

59
from google.protobuf.json_format import MessageToDict, ParseDict
610

711
from a2a.compat.v0_3 import types as types_v03
12+
from a2a.server.models import PushNotificationConfigModel, TaskModel
813
from a2a.types import a2a_pb2 as pb2_v10
914

1015

@@ -1367,3 +1372,77 @@ def to_compat_get_extended_agent_card_request(
13671372
) -> types_v03.GetAuthenticatedExtendedCardRequest:
13681373
"""Convert get extended agent card request to v0.3 compat type."""
13691374
return types_v03.GetAuthenticatedExtendedCardRequest(id=request_id)
1375+
1376+
1377+
def core_to_compat_task_model(task: pb2_v10.Task, owner: str) -> TaskModel:
1378+
"""Converts a 1.0 core Task to a TaskModel using v0.3 JSON structure."""
1379+
compat_task = to_compat_task(task)
1380+
data = compat_task.model_dump(mode='json')
1381+
1382+
return TaskModel(
1383+
id=task.id,
1384+
context_id=task.context_id,
1385+
owner=owner,
1386+
status=data.get('status'),
1387+
history=data.get('history'),
1388+
artifacts=data.get('artifacts'),
1389+
task_metadata=data.get('metadata'),
1390+
protocol_version='0.3',
1391+
)
1392+
1393+
1394+
def compat_task_model_to_core(task_model: TaskModel) -> pb2_v10.Task:
1395+
"""Converts a TaskModel with v0.3 structure to a 1.0 core Task."""
1396+
compat_task = types_v03.Task(
1397+
id=task_model.id,
1398+
context_id=task_model.context_id,
1399+
status=types_v03.TaskStatus.model_validate(task_model.status),
1400+
artifacts=(
1401+
[types_v03.Artifact.model_validate(a) for a in task_model.artifacts]
1402+
if task_model.artifacts
1403+
else []
1404+
),
1405+
history=(
1406+
[types_v03.Message.model_validate(h) for h in task_model.history]
1407+
if task_model.history
1408+
else []
1409+
),
1410+
metadata=task_model.task_metadata,
1411+
)
1412+
return to_core_task(compat_task)
1413+
1414+
1415+
def core_to_compat_push_notification_config_model(
1416+
task_id: str,
1417+
config: pb2_v10.TaskPushNotificationConfig,
1418+
owner: str,
1419+
fernet: 'Fernet | None' = None,
1420+
) -> PushNotificationConfigModel:
1421+
"""Converts a 1.0 core TaskPushNotificationConfig to a PushNotificationConfigModel using v0.3 JSON structure."""
1422+
compat_config = to_compat_push_notification_config(config)
1423+
1424+
json_payload = compat_config.model_dump_json().encode('utf-8')
1425+
data_to_store = fernet.encrypt(json_payload) if fernet else json_payload
1426+
1427+
return PushNotificationConfigModel(
1428+
task_id=task_id,
1429+
config_id=config.id,
1430+
owner=owner,
1431+
config_data=data_to_store,
1432+
protocol_version='0.3',
1433+
)
1434+
1435+
1436+
def compat_push_notification_config_model_to_core(
1437+
model_instance: str, task_id: str
1438+
) -> pb2_v10.TaskPushNotificationConfig:
1439+
"""Converts a PushNotificationConfigModel with v0.3 structure back to a 1.0 core TaskPushNotificationConfig."""
1440+
inner_config = types_v03.PushNotificationConfig.model_validate_json(
1441+
model_instance
1442+
)
1443+
return to_core_task_push_notification_config(
1444+
types_v03.TaskPushNotificationConfig(
1445+
task_id=task_id,
1446+
push_notification_config=inner_config,
1447+
)
1448+
)

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
AsyncSession,
1414
async_sessionmaker,
1515
)
16-
from sqlalchemy.orm import (
17-
class_mapper,
18-
)
16+
from sqlalchemy.orm import class_mapper
1917
except ImportError as e:
2018
raise ImportError(
2119
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
@@ -26,8 +24,11 @@
2624
"or 'pip install a2a-sdk[sql]'"
2725
) from e
2826

29-
from a2a.compat.v0_3 import conversions
30-
from a2a.compat.v0_3 import types as types_v03
27+
from collections.abc import Callable
28+
29+
from a2a.compat.v0_3.conversions import (
30+
compat_push_notification_config_model_to_core,
31+
)
3132
from a2a.server.context import ServerCallContext
3233
from a2a.server.models import (
3334
Base,
@@ -44,7 +45,6 @@
4445
if TYPE_CHECKING:
4546
from cryptography.fernet import Fernet
4647

47-
4848
logger = logging.getLogger(__name__)
4949

5050

@@ -61,14 +61,34 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
6161
config_model: type[PushNotificationConfigModel]
6262
_fernet: 'Fernet | None'
6363
owner_resolver: OwnerResolver
64+
core_to_model_conversion: (
65+
Callable[
66+
[str, TaskPushNotificationConfig, str, 'Fernet | None'],
67+
PushNotificationConfigModel,
68+
]
69+
| None
70+
)
71+
model_to_core_conversion: (
72+
Callable[[PushNotificationConfigModel], TaskPushNotificationConfig]
73+
| None
74+
)
6475

65-
def __init__(
76+
def __init__( # noqa: PLR0913
6677
self,
6778
engine: AsyncEngine,
6879
create_table: bool = True,
6980
table_name: str = 'push_notification_configs',
7081
encryption_key: str | bytes | None = None,
7182
owner_resolver: OwnerResolver = resolve_user_scope,
83+
core_to_model_conversion: Callable[
84+
[str, TaskPushNotificationConfig, str, 'Fernet | None'],
85+
PushNotificationConfigModel,
86+
]
87+
| None = None,
88+
model_to_core_conversion: Callable[
89+
[PushNotificationConfigModel], TaskPushNotificationConfig
90+
]
91+
| None = None,
7292
) -> None:
7393
"""Initializes the DatabasePushNotificationConfigStore.
7494
@@ -80,6 +100,8 @@ def __init__(
80100
If provided, `config_data` will be encrypted in the database.
81101
The key must be a URL-safe base64-encoded 32-byte key.
82102
owner_resolver: Function to resolve the owner from the context.
103+
core_to_model_conversion: Optional function to convert a TaskPushNotificationConfig to a TaskPushNotificationConfigModel.
104+
model_to_core_conversion: Optional function to convert a TaskPushNotificationConfigModel to a TaskPushNotificationConfig.
83105
"""
84106
logger.debug(
85107
'Initializing DatabasePushNotificationConfigStore with existing engine, table: %s',
@@ -98,6 +120,8 @@ def __init__(
98120
else create_push_notification_config_model(table_name)
99121
)
100122
self._fernet = None
123+
self.core_to_model_conversion = core_to_model_conversion
124+
self.model_to_core_conversion = model_to_core_conversion
101125

102126
if encryption_key:
103127
try:
@@ -152,6 +176,11 @@ def _to_orm(
152176
153177
The config data is serialized to JSON bytes, and encrypted if a key is configured.
154178
"""
179+
if self.core_to_model_conversion:
180+
return self.core_to_model_conversion(
181+
task_id, config, owner, self._fernet
182+
)
183+
155184
json_payload = MessageToJson(config).encode('utf-8')
156185

157186
if self._fernet:
@@ -174,6 +203,9 @@ def _from_orm(
174203
175204
Handles decryption if a key is configured, with a fallback to plain JSON.
176205
"""
206+
if self.model_to_core_conversion:
207+
return self.model_to_core_conversion(model_instance)
208+
177209
payload = model_instance.config_data
178210

179211
if self._fernet:
@@ -359,12 +391,7 @@ def _parse_config(
359391
"""
360392
if protocol_version == '1.0':
361393
return Parse(json_payload, TaskPushNotificationConfig())
362-
inner_config = types_v03.PushNotificationConfig.model_validate_json(
363-
json_payload
364-
)
365-
return conversions.to_core_task_push_notification_config(
366-
types_v03.TaskPushNotificationConfig(
367-
task_id=task_id or '',
368-
push_notification_config=inner_config,
369-
)
394+
395+
return compat_push_notification_config_model_to_core(
396+
json_payload, task_id or ''
370397
)

src/a2a/server/tasks/database_task_store.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
11
import logging
22

3+
from collections.abc import Callable
34
from datetime import datetime, timezone
45

56

67
try:
7-
from sqlalchemy import (
8-
Table,
9-
and_,
10-
delete,
11-
func,
12-
or_,
13-
select,
14-
)
8+
from sqlalchemy import Table, and_, delete, func, or_, select
159
from sqlalchemy.ext.asyncio import (
1610
AsyncEngine,
1711
AsyncSession,
1812
async_sessionmaker,
1913
)
20-
from sqlalchemy.orm import (
21-
class_mapper,
22-
)
14+
from sqlalchemy.orm import class_mapper
2315
except ImportError as e:
2416
raise ImportError(
2517
'DatabaseTaskStore requires SQLAlchemy and a database driver. '
@@ -29,11 +21,11 @@
2921
"'pip install a2a-sdk[sqlite]', "
3022
"or 'pip install a2a-sdk[sql]'"
3123
) from e
32-
3324
from google.protobuf.json_format import MessageToDict, ParseDict
3425

35-
from a2a.compat.v0_3 import conversions
36-
from a2a.compat.v0_3 import types as types_v03
26+
from a2a.compat.v0_3.conversions import (
27+
compat_task_model_to_core,
28+
)
3729
from a2a.server.context import ServerCallContext
3830
from a2a.server.models import Base, TaskModel, create_task_model
3931
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -60,13 +52,18 @@ class DatabaseTaskStore(TaskStore):
6052
_initialized: bool
6153
task_model: type[TaskModel]
6254
owner_resolver: OwnerResolver
55+
core_to_model_conversion: Callable[[Task, str], TaskModel] | None = None
56+
model_to_core_conversion: Callable[[TaskModel], Task] | None = None
6357

64-
def __init__(
58+
def __init__( # noqa: PLR0913
6559
self,
6660
engine: AsyncEngine,
6761
create_table: bool = True,
6862
table_name: str = 'tasks',
6963
owner_resolver: OwnerResolver = resolve_user_scope,
64+
core_to_model_conversion: Callable[[Task, str], TaskModel]
65+
| None = None,
66+
model_to_core_conversion: Callable[[TaskModel], Task] | None = None,
7067
) -> None:
7168
"""Initializes the DatabaseTaskStore.
7269
@@ -75,6 +72,8 @@ def __init__(
7572
create_table: If true, create tasks table on initialization.
7673
table_name: Name of the database table. Defaults to 'tasks'.
7774
owner_resolver: Function to resolve the owner from the context.
75+
core_to_model_conversion: Optional function to convert a Task to a TaskModel.
76+
model_to_core_conversion: Optional function to convert a TaskModel to a Task.
7877
"""
7978
logger.debug(
8079
'Initializing DatabaseTaskStore with existing engine, table: %s',
@@ -87,6 +86,8 @@ def __init__(
8786
self.create_table = create_table
8887
self._initialized = False
8988
self.owner_resolver = owner_resolver
89+
self.core_to_model_conversion = core_to_model_conversion
90+
self.model_to_core_conversion = model_to_core_conversion
9091

9192
self.task_model = (
9293
TaskModel
@@ -119,6 +120,9 @@ async def _ensure_initialized(self) -> None:
119120

120121
def _to_orm(self, task: Task, owner: str) -> TaskModel:
121122
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
123+
if self.core_to_model_conversion:
124+
return self.core_to_model_conversion(task, owner)
125+
122126
return self.task_model(
123127
id=task.id,
124128
context_id=task.context_id,
@@ -140,6 +144,9 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
140144

141145
def _from_orm(self, task_model: TaskModel) -> Task:
142146
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
147+
if self.model_to_core_conversion:
148+
return self.model_to_core_conversion(task_model)
149+
143150
if task_model.protocol_version == '1.0':
144151
task = Task(
145152
id=task_model.id,
@@ -160,29 +167,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
160167
return task
161168

162169
# Legacy conversion
163-
legacy_task = types_v03.Task(
164-
id=task_model.id,
165-
context_id=task_model.context_id,
166-
status=types_v03.TaskStatus.model_validate(task_model.status),
167-
artifacts=(
168-
[
169-
types_v03.Artifact.model_validate(a)
170-
for a in task_model.artifacts
171-
]
172-
if task_model.artifacts
173-
else []
174-
),
175-
history=(
176-
[
177-
types_v03.Message.model_validate(m)
178-
for m in task_model.history
179-
]
180-
if task_model.history
181-
else []
182-
),
183-
metadata=task_model.task_metadata or {},
184-
)
185-
return conversions.to_core_task(legacy_task)
170+
return compat_task_model_to_core(task_model)
186171

187172
async def save(
188173
self, task: Task, context: ServerCallContext | None = None

0 commit comments

Comments
 (0)