Skip to content

Commit e4bb183

Browse files
yfukai and JoOkuma authored
Dict-like metadata interface & private metadata (#260)
* added private metadata machinery
* before adding private
* added private metadata view
* renamed func
* added private metadata to_geff
* further fix

---------

Co-authored-by: Jordão Bragantini <jordao.bragantini@czbiohub.org>
1 parent 55c858b commit e4bb183

11 files changed

Lines changed: 208 additions & 87 deletions

File tree

src/tracksdata/array/_graph_array.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def _validate_shape(
2323
"""Helper function to validate the shape argument."""
2424
if shape is None:
2525
try:
26-
shape = graph.metadata()["shape"]
26+
shape = graph.metadata["shape"]
2727
except KeyError as e:
2828
raise KeyError(
2929
f"`shape` is required to `{func_name}`. "

src/tracksdata/functional/_test/test_napari.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None:
3131

3232
shape = (2, 10, 22, 32)
3333
if metadata_shape:
34-
graph.update_metadata(shape=shape)
34+
graph.metadata.update(shape=shape)
3535
arg_shape = None
3636
else:
3737
arg_shape = shape

src/tracksdata/graph/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
"""Graph backends for representing tracking data as directed graphs in memory or on disk."""
22

3-
from tracksdata.graph._base_graph import BaseGraph
3+
from tracksdata.graph._base_graph import BaseGraph, MetadataView
44
from tracksdata.graph._graph_view import GraphView
55
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph
66
from tracksdata.graph._sql_graph import SQLGraph
77

88
InMemoryGraph = RustWorkXGraph
99

10-
__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"]
10+
__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"]

src/tracksdata/graph/_base_graph.py

Lines changed: 128 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,75 @@
4242
T = TypeVar("T", bound="BaseGraph")
4343

4444

45+
class MetadataView(dict[str, Any]):
46+
"""Dictionary-like metadata view that syncs mutations back to the graph."""
47+
48+
_MISSING = object()
49+
50+
def __init__(
51+
self,
52+
graph: "BaseGraph",
53+
data: dict[str, Any],
54+
*,
55+
is_public: bool = True,
56+
) -> None:
57+
super().__init__(data)
58+
self._graph = graph
59+
self._is_public = is_public
60+
61+
def __setitem__(self, key: str, value: Any) -> None:
62+
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value})
63+
super().__setitem__(key, value)
64+
65+
def __delitem__(self, key: str) -> None:
66+
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
67+
super().__delitem__(key)
68+
69+
def pop(self, key: str, default: Any = _MISSING) -> Any:
70+
self._graph._validate_metadata_key(key, is_public=self._is_public)
71+
72+
if key not in self:
73+
if default is self._MISSING:
74+
raise KeyError(key)
75+
return default
76+
77+
value = super().__getitem__(key)
78+
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
79+
super().pop(key, None)
80+
return value
81+
82+
def popitem(self) -> tuple[str, Any]:
83+
key, value = super().popitem()
84+
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
85+
return key, value
86+
87+
def clear(self) -> None:
88+
keys = list(self.keys())
89+
for key in keys:
90+
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
91+
super().clear()
92+
93+
def setdefault(self, key: str, default: Any = None) -> Any:
94+
if key in self:
95+
return super().__getitem__(key)
96+
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default})
97+
super().__setitem__(key, default)
98+
return default
99+
100+
def update(self, *args, **kwargs) -> None:
101+
updates = dict(*args, **kwargs)
102+
if updates:
103+
self._graph._set_metadata_with_validation(is_public=self._is_public, **updates)
104+
super().update(updates)
105+
106+
45107
class BaseGraph(abc.ABC):
46108
"""
47109
Base class for a graph backend.
48110
"""
49111

112+
_PRIVATE_METADATA_PREFIX = "__private_"
113+
50114
node_added = Signal(int, object)
51115
node_removed = Signal(int, object)
52116
node_updated = Signal(int, object, object)
@@ -1187,7 +1251,8 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
11871251
node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID)
11881252

11891253
graph = cls(**kwargs)
1190-
graph.update_metadata(**other.metadata())
1254+
graph.metadata.update(other.metadata)
1255+
graph._private_metadata.update(other._private_metadata_for_copy())
11911256

11921257
current_node_attr_schemas = graph._node_attr_schemas()
11931258
for k, v in other._node_attr_schemas().items():
@@ -1792,7 +1857,8 @@ def to_geff(
17921857
for k, v in edge_attrs.to_dict().items()
17931858
}
17941859

1795-
td_metadata = self.metadata().copy()
1860+
td_metadata = self.metadata.copy()
1861+
td_metadata.update(self._private_metadata_for_copy())
17961862
td_metadata.pop("geff", None) # avoid geff being written multiple times
17971863

17981864
geff_metadata = geff.GeffMetadata(
@@ -1830,57 +1896,88 @@ def to_geff(
18301896
zarr_format=zarr_format,
18311897
)
18321898

1833-
@abc.abstractmethod
1834-
def metadata(self) -> dict[str, Any]:
1899+
@property
1900+
def metadata(self) -> MetadataView:
18351901
"""
18361902
Return the metadata of the graph.
18371903
18381904
Returns
18391905
-------
1840-
dict[str, Any]
1906+
MetadataView
18411907
The metadata of the graph as a dictionary.
18421908
18431909
Examples
18441910
--------
18451911
```python
1846-
metadata = graph.metadata()
1912+
metadata = graph.metadata
18471913
print(metadata["shape"])
18481914
```
18491915
"""
1916+
return MetadataView(
1917+
graph=self,
1918+
data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)},
1919+
is_public=True,
1920+
)
18501921

1851-
@abc.abstractmethod
1852-
def update_metadata(self, **kwargs) -> None:
1922+
@property
1923+
def _private_metadata(self) -> MetadataView:
1924+
return MetadataView(
1925+
graph=self,
1926+
data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)},
1927+
is_public=False,
1928+
)
1929+
1930+
def _private_metadata_for_copy(self) -> dict[str, Any]:
18531931
"""
1854-
Set or update metadata for the graph.
1932+
Return private metadata entries that should be propagated by `from_other` or `to_geff`.
1933+
Backends can override this to exclude backend-specific private metadata.
1934+
"""
1935+
return dict(self._private_metadata)
18551936

1856-
Parameters
1857-
----------
1858-
**kwargs : Any
1859-
The metadata items to set by key. Values will be stored as JSON.
1937+
@classmethod
1938+
def _is_private_metadata_key(cls, key: str) -> bool:
1939+
return key.startswith(cls._PRIVATE_METADATA_PREFIX)
1940+
1941+
def _validate_metadata_key(self, key: str, *, is_public: bool) -> None:
1942+
if not isinstance(key, str):
1943+
raise TypeError(f"Metadata key must be a string. Got {type(key)}.")
1944+
is_private_key = self._is_private_metadata_key(key)
1945+
if is_public and is_private_key:
1946+
raise ValueError(f"Metadata key '{key}' is reserved for internal use.")
1947+
if not is_public and not is_private_key:
1948+
raise ValueError(
1949+
f"Metadata key '{key}' is not private. Private metadata keys must start with "
1950+
f"'{self._PRIVATE_METADATA_PREFIX}'."
1951+
)
18601952

1861-
Examples
1862-
--------
1863-
```python
1864-
graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr")
1865-
graph.update_metadata(description="Tracking data from experiment 1")
1866-
```
1867-
"""
1953+
def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None:
1954+
for key in keys:
1955+
self._validate_metadata_key(key, is_public=is_public)
1956+
1957+
def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None:
1958+
self._validate_metadata_keys(kwargs.keys(), is_public=is_public)
1959+
self._update_metadata(**kwargs)
1960+
1961+
def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None:
1962+
self._validate_metadata_key(key, is_public=is_public)
1963+
self._remove_metadata(key)
18681964

18691965
@abc.abstractmethod
1870-
def remove_metadata(self, key: str) -> None:
1966+
def _metadata(self) -> dict[str, Any]:
1967+
"""
1968+
Return the full metadata including private keys.
18711969
"""
1872-
Remove a metadata key from the graph.
18731970

1874-
Parameters
1875-
----------
1876-
key : str
1877-
The key of the metadata to remove.
1971+
@abc.abstractmethod
1972+
def _update_metadata(self, **kwargs) -> None:
1973+
"""
1974+
Backend-specific metadata update implementation without public key validation.
1975+
"""
18781976

1879-
Examples
1880-
--------
1881-
```python
1882-
graph.remove_metadata("shape")
1883-
```
1977+
@abc.abstractmethod
1978+
def _remove_metadata(self, key: str) -> None:
1979+
"""
1980+
Backend-specific metadata removal implementation without public key validation.
18841981
"""
18851982

18861983
def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph":

src/tracksdata/graph/_graph_view.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -870,11 +870,11 @@ def copy(self, **kwargs) -> "GraphView":
870870
"Use `detach` to create a new reference-less graph with the same nodes and edges."
871871
)
872872

873-
def metadata(self) -> dict[str, Any]:
874-
return self._root.metadata()
873+
def _metadata(self) -> dict[str, Any]:
874+
return self._root._metadata()
875875

876-
def update_metadata(self, **kwargs) -> None:
877-
self._root.update_metadata(**kwargs)
876+
def _update_metadata(self, **kwargs) -> None:
877+
self._root._update_metadata(**kwargs)
878878

879-
def remove_metadata(self, key: str) -> None:
880-
self._root.remove_metadata(key)
879+
def _remove_metadata(self, key: str) -> None:
880+
self._root._remove_metadata(key)

src/tracksdata/graph/_rustworkx_graph.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:
371371

372372
elif not isinstance(self._graph.attrs, dict):
373373
LOG.warning(
374-
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`",
374+
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata`",
375375
self._graph.attrs,
376376
)
377377
self._graph.attrs = {
@@ -1516,13 +1516,13 @@ def edge_id(self, source_id: int, target_id: int) -> int:
15161516
"""
15171517
return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID]
15181518

1519-
def metadata(self) -> dict[str, Any]:
1519+
def _metadata(self) -> dict[str, Any]:
15201520
return self._graph.attrs
15211521

1522-
def update_metadata(self, **kwargs) -> None:
1522+
def _update_metadata(self, **kwargs) -> None:
15231523
self._graph.attrs.update(kwargs)
15241524

1525-
def remove_metadata(self, key: str) -> None:
1525+
def _remove_metadata(self, key: str) -> None:
15261526
self._graph.attrs.pop(key, None)
15271527

15281528
def edge_list(self) -> list[list[int, int]]:

src/tracksdata/graph/_sql_graph.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2076,19 +2076,19 @@ def remove_edge(
20762076
raise ValueError(f"Edge {edge_id} does not exist in the graph.")
20772077
session.commit()
20782078

2079-
def metadata(self) -> dict[str, Any]:
2079+
def _metadata(self) -> dict[str, Any]:
20802080
with Session(self._engine) as session:
20812081
result = session.query(self.Metadata).all()
20822082
return {row.key: row.value for row in result}
20832083

2084-
def update_metadata(self, **kwargs) -> None:
2084+
def _update_metadata(self, **kwargs) -> None:
20852085
with Session(self._engine) as session:
20862086
for key, value in kwargs.items():
20872087
metadata_entry = self.Metadata(key=key, value=value)
20882088
session.merge(metadata_entry)
20892089
session.commit()
20902090

2091-
def remove_metadata(self, key: str) -> None:
2091+
def _remove_metadata(self, key: str) -> None:
20922092
with Session(self._engine) as session:
20932093
session.query(self.Metadata).filter(self.Metadata.key == key).delete()
20942094
session.commit()

0 commit comments

Comments (0)