Skip to content

Commit c4ed212

Browse files
authored
Allow component class to be serialized
1 parent 96f559d commit c4ed212

1 file changed

Lines changed: 11 additions & 20 deletions

File tree

pcs/component.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Generic, TypeVar
22

33
from omegaconf import DictConfig, OmegaConf
4-
54
from pcs.init import initialize_object_nones
65

76
T = TypeVar("T")
@@ -17,9 +16,9 @@ def init_with_conf(
1716

1817
def __init__(self, conf: DictConfig, runtime: T):
1918
duplicate_keys = set(conf.keys()).intersection(set(runtime.__dict__))
20-
assert (
21-
len(duplicate_keys) == 0
22-
), f"Error initializing component - duplicate keys found: {duplicate_keys}"
19+
assert len(duplicate_keys) == 0, (
20+
f"Error initializing component - duplicate keys found: {duplicate_keys}"
21+
)
2322

2423
super().__setattr__("conf", conf)
2524
super().__setattr__("runtime", runtime)
@@ -39,6 +38,8 @@ def get_runtime(self):
3938
return super().__getattribute__("runtime")
4039

4140
def __getattribute__(self, name):
41+
if name not in ["conf", "runtime"]:
42+
return super().__getattribute__(name)
4243
conf: DictConfig = super().__getattribute__("conf")
4344
if name in conf.keys():
4445
if name in conf:
@@ -47,28 +48,26 @@ def __getattribute__(self, name):
4748
runtime: T = super().__getattribute__("runtime")
4849
if hasattr(runtime, name):
4950
return getattr(runtime, name)
50-
return super().__getattribute__(name)
5151

52-
def __setattr__(self, name, item):
52+
def __setattr__(self, name: str, item):
53+
if name in ["conf", "runtime"]:
54+
return super().__setattr__(name, item)
5355
conf: DictConfig = super().__getattribute__("conf")
5456
if name in conf.keys():
5557
return setattr(conf, name, item)
5658
runtime: T = super().__getattribute__("runtime")
5759
if hasattr(runtime, name):
5860
return setattr(runtime, name, item)
59-
if name == "conf":
60-
return super().__setattr__("conf", conf)
61-
if name == "runtime":
62-
return super().__setattr__("runtime", runtime)
6361
raise AttributeError(f"{name} not found in class Component")
6462

6563
def __hasattr__(self, name):
66-
return hasattr(self.runtime, name) or hasattr(self.conf, name)
64+
return super().__hasattr__(name) or hasattr(self.runtime, name) or hasattr(self.conf, name)
6765

6866
def __repr__(self):
6967
conf = super().__getattribute__("conf")
7068
runtime = super().__getattribute__("runtime")
71-
return f"Component({conf=}, {runtime=})"
69+
sealed = super().__getattribute__("sealed")
70+
return f"Component({conf=}, {runtime=}, {sealed=})"
7271

7372
def get_non_null_members_as_dict(self):
7473
resolved_conf = OmegaConf.to_container(
@@ -79,11 +78,3 @@ def get_non_null_members_as_dict(self):
7978
runtime = super().__getattribute__("runtime")
8079
result |= {k: v for k, v in runtime.__dict__.items() if v is not None}
8180
return result
82-
83-
84-
def get_conf_as_dict(obj: object):
85-
return {
86-
name: obj.__getattribute__(name)
87-
for name in obj.__annotations__
88-
if obj.__getattribute__(name) is not None
89-
}

0 commit comments

Comments
 (0)