Skip to content

Commit bf0e6d3

Browse files
authored
Merge pull request #441 from Modalities/wandb_entity
added entity parameter for w&b logging
2 parents d328e2e + 930e4db commit bf0e6d3

3 files changed

Lines changed: 12 additions & 3 deletions

File tree

src/modalities/config/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ class EvaluationResultToDiscSubscriberConfig(BaseModel):
492492

493493
class WandBEvaluationResultSubscriberConfig(BaseModel):
494494
global_rank: int
495+
entity: Optional[str] = None
495496
project: str
496497
experiment_id: str
497498
mode: WandbMode

src/modalities/logging_broker/subscriber_impl/results_subscriber.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,16 @@ def __init__(
6565
project: str,
6666
experiment_id: str,
6767
mode: WandbMode,
68-
logging_directory: Path,
68+
logging_directory: Path | None,
6969
config_file_path: Path,
70+
entity: str | None = None,
7071
) -> None:
7172
super().__init__()
7273

7374
with open(config_file_path, "r", encoding="utf-8") as file:
7475
config = yaml.safe_load(file)
7576
self.run = wandb.init(
77+
entity=entity,
7678
project=project,
7779
name=experiment_id,
7880
mode=mode.value.lower(),
@@ -81,7 +83,7 @@ def __init__(
8183
settings=wandb.Settings(init_timeout=120),
8284
)
8385

84-
self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config")
86+
self.run.log_artifact(config_file_path, name=f"config_{self.run.id}", type="config")
8587

8688
def consume_dict(self, message_dict: dict[str, Any]):
8789
for k, v in message_dict.items():

src/modalities/logging_broker/subscriber_impl/subscriber_factory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def get_wandb_result_subscriber(
6868
mode: WandbMode,
6969
config_file_path: Path,
7070
directory: Optional[Path] = None,
71+
entity: Optional[str] = None,
7172
) -> WandBEvaluationResultSubscriber:
7273
if global_rank == 0 and (mode != WandbMode.DISABLED):
7374
if directory is not None:
@@ -88,7 +89,12 @@ def get_wandb_result_subscriber(
8889
absolute_dir = None
8990

9091
result_subscriber = WandBEvaluationResultSubscriber(
91-
project, experiment_id, mode, absolute_dir, config_file_path
92+
project=project,
93+
experiment_id=experiment_id,
94+
mode=mode,
95+
logging_directory=absolute_dir,
96+
config_file_path=config_file_path,
97+
entity=entity,
9298
)
9399
else:
94100
result_subscriber = ResultsSubscriberFactory.get_dummy_result_subscriber()

0 commit comments

Comments
 (0)