-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconfig.py
More file actions
131 lines (99 loc) · 3.97 KB
/
config.py
File metadata and controls
131 lines (99 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
from dataclasses import dataclass
from os import path
from typing import List, Optional

from torch.optim.lr_scheduler import ExponentialLR

from data_utils.dataset import load_splits
from model.interface import RecursiveModel
from model.paths import PATHSProcessor
from preprocess import loader
@dataclass
class ModelConfig:
    """Marker base class for model-dependent configuration.

    Concrete models subclass this (e.g. PATHSProcessorConfig below) and
    Config.model_config holds an instance of the appropriate subclass.
    """
    pass
# Model configuration (model dependent)
@dataclass
class PATHSProcessorConfig(ModelConfig):
    """Model-dependent configuration for the PATHS processor.

    Config.load instantiates this from the "model_config" entry of
    config.json via keyword expansion; every field has a default, so
    missing JSON keys fall back to the values below.
    """
    hierarchical_ctx: bool = True  # Config.load requires this to be True whenever lstm=True
    slide_ctx_mode: str = "residual"  # residual / concat / none
    patch_embed_dim: int = 1024  # presumably the dimensionality of input patch embeddings — confirm against model.paths
    dropout: float = 0.0
    patch_size: int = 256  # only needed for visualisation etc. and not at train time
    importance_mode: str = "mul"  # mul / none
    trans_dim: int = 192  # NOTE(review): trans_* names suggest transformer width/heads/depth — confirm in PATHSProcessor
    trans_heads: int = 4
    trans_layers: int = 2
    pos_encoding_mode: str = "1d"  # 1d / 2d
    importance_mlp_hidden_dim: int = 128
    hierarchical_ctx_mlp_hidden_dim: int = 256
    lstm: bool = True
# Training stats etc (model independent)
@dataclass
class Config:
    """Top-level, model-independent training configuration.

    Bundles the model-specific ``model_config`` with recursion settings,
    data locations and training hyperparameters, and provides factory
    helpers for the model, the dataset splits and the LR scheduler.
    """

    model_config: ModelConfig

    # Recursion related
    base_power: float
    magnification_factor: int
    num_levels: int
    num_epochs: int
    top_k_patches: List[int]  # how many patches to keep at each level; -1 denotes keep all patches
    model_type: str

    # Data
    wsi_dir: str
    csv_path: str
    nbins: int = 4
    loss: str = "nll"
    task: str = "survival"  # survival / subtype_classification
    filter_to_subtypes: Optional[List[str]] = None  # was annotated List[str]; None is a valid (default) value
    preprocess_dir: Optional[str] = None  # was annotated str; None is a valid (default) value

    # Training
    # NOTE: Config.load broadcasts an int batch_size from JSON into a per-level
    # list, so instances built by load() hold List[int] here despite the default.
    batch_size: int = 32
    save_epochs: int = 10
    eval_epochs: int = 1
    lr: float = 2e-5
    lr_decay_per_epoch: float = 0.99
    seed: int = 0
    early_stopping: bool = False
    weight_decay: float = 1e-2
    min_epochs: int = 0  # min epochs for early stopping
    root_name: str = ""  # for tracking multiple folds
    hipt_splits: bool = False
    hipt_val_proportion: float = 0  # Split part of the HIPT training set off into a val set

    @staticmethod
    def load(root_path: str, test_mode: bool = False) -> "Config":
        """Load a Config object from ``[root_path]/config.json``.

        Scalar convenience values in the JSON are expanded: an int
        ``top_k_patches`` is broadcast over the num_levels - 1 non-final
        levels, a list ``num_epochs`` collapses to its first entry, and an
        int ``batch_size`` is broadcast to one entry per level.

        Unless ``test_mode`` is set, the preprocess directory is also
        registered globally with the preprocess loader.

        Raises:
            AssertionError: if the directory or config.json is missing, or
                if the PATHS config enables lstm without hierarchical_ctx.
            NotImplementedError: for an unknown ``model_type``.
        """
        jsonpath = path.join(root_path, "config.json")
        assert path.isdir(root_path), f"Model directory '{root_path}' not found!"
        assert path.isfile(jsonpath), f"config.json not found in directory '{root_path}'."

        with open(jsonpath, "r") as file:
            data = json.load(file)  # idiomatic: parse the stream directly instead of loads(read())

        # Expand scalar shorthands (see docstring).
        if isinstance(data["top_k_patches"], int):
            data["top_k_patches"] = [data["top_k_patches"]] * (data["num_levels"] - 1)
        if isinstance(data["num_epochs"], list):
            data["num_epochs"] = data["num_epochs"][0]
        if isinstance(data["batch_size"], int):
            data["batch_size"] = [data["batch_size"]] * data["num_levels"]

        if data["model_type"] == "PATHS":
            data["model_config"] = PATHSProcessorConfig(**data["model_config"])
            c = data["model_config"]
            if c.lstm:
                assert c.hierarchical_ctx, "If LSTM mode is enabled, hierarchical context must be enabled."
        else:
            raise NotImplementedError(f"Unknown model type '{data['model_type']}'")

        config = Config(**data)
        if not test_mode:
            loader.set_preprocess_dir(config.preprocess_dir)
        return config

    def power_levels(self) -> List[float]:
        """Magnification at each level: base_power * magnification_factor**i for i in 0..num_levels-1."""
        return [self.base_power * self.magnification_factor ** i for i in range(self.num_levels)]

    def get_model(self) -> RecursiveModel:
        """Instantiate the model described by ``model_type`` and ``model_config``."""
        if self.model_type == "PATHS":
            return RecursiveModel(PATHSProcessor, self.model_config, train_config=self)
        raise NotImplementedError(f"Unknown model '{self.model_type}'.")

    # Load train/test/val split with proportions given by props (a list of 3 floats)
    def get_dataset(self, props, seed, ctx_dim, **kwargs):
        """Load the train/test/val splits for this config (delegates to load_splits)."""
        return load_splits(props, seed, ctx_dim, self, **kwargs)

    def get_lr_scheduler(self, optimizer) -> ExponentialLR:
        """Per-epoch exponential LR decay with factor ``lr_decay_per_epoch``."""
        return ExponentialLR(optimizer, self.lr_decay_per_epoch)