1- from __future__ import annotations
2-
31import gc
42import os
53import shutil
64from abc import ABC
75from collections import defaultdict
8- from typing import TYPE_CHECKING , Any , Callable , Literal
6+ from collections .abc import Callable , Mapping
7+ from typing import Any , Literal
98
109import numpy as np
1110import torch
1211import wandb
1312from sklearn .metrics import f1_score
1413from torch import BoolTensor , Tensor , nn
1514from torch .nn .functional import softmax
15+ from torch .utils .data import DataLoader
1616from torch .utils .tensorboard import SummaryWriter
1717from tqdm import tqdm
1818
1919from aviary import ROOT
20-
21- if TYPE_CHECKING :
22- from collections .abc import Mapping
23-
24- from torch .utils .data import DataLoader
25-
26- from aviary .data import InMemoryDataLoader
20+ from aviary .data import InMemoryDataLoader , Normalizer
2721
2822TaskType = Literal ["regression" , "classification" ]
2923
@@ -129,6 +123,14 @@ def fit(
129123 for metric , val in metrics .items ():
130124 writer .add_scalar (f"{ task } /train/{ metric } " , val , epoch )
131125
126+ if writer == "wandb" :
127+ flat_train_metrics = {}
128+ for task , metrics in train_metrics .items ():
129+ for metric , val in metrics .items ():
130+ flat_train_metrics [f"train_{ task } _{ metric .lower ()} " ] = val
131+ flat_train_metrics ["epoch" ] = epoch
132+ wandb .log (flat_train_metrics )
133+
132134 # Validation
133135 if val_loader is not None :
134136 with torch .no_grad ():
@@ -149,6 +151,14 @@ def fit(
149151 f"{ task } /validation/{ metric } " , val , epoch
150152 )
151153
154+ if writer == "wandb" :
155+ flat_val_metrics = {}
156+ for task , metrics in val_metrics .items ():
157+ for metric , val in metrics .items ():
158+ flat_val_metrics [f"val_{ task } _{ metric .lower ()} " ] = val
159+ flat_val_metrics ["epoch" ] = epoch
160+ wandb .log (flat_val_metrics )
161+
152162 # TODO test all tasks to see if they are best,
153163 # save a best model if any is best.
154164 # TODO what are the costs of this approach.
@@ -207,9 +217,6 @@ def fit(
207217 # catch memory leak
208218 gc .collect ()
209219
210- if writer == "wandb" :
211- wandb .log ({"train" : train_metrics , "validation" : val_metrics })
212-
213220 except KeyboardInterrupt :
214221 pass
215222
@@ -271,7 +278,11 @@ def evaluate(
271278 mixed_loss : Tensor = 0 # type: ignore[assignment]
272279
273280 for target_name , targets , output , normalizer in zip (
274- self .target_names , targets_list , outputs , normalizer_dict .values ()
281+ self .target_names ,
282+ targets_list ,
283+ outputs ,
284+ normalizer_dict .values (),
285+ strict = False ,
275286 ):
276287 task , loss_func = loss_dict [target_name ]
277288 target_metrics = epoch_metrics [target_name ]
@@ -318,7 +329,7 @@ def evaluate(
318329 else :
319330 raise ValueError (f"invalid task: { task } " )
320331
321- epoch_metrics [ target_name ] ["Loss" ].append (loss .cpu ().item ())
332+ target_metrics ["Loss" ].append (loss .cpu ().item ())
322333
323334 # NOTE multitasking currently just uses a direct sum of individual
324335 # target losses this should be okay but is perhaps sub-optimal
@@ -396,11 +407,13 @@ def predict(
396407 # for multitask learning
397408 targets = tuple (
398409 torch .cat (targets , dim = 0 ).view (- 1 ).cpu ().numpy ()
399- for targets in zip (* test_targets )
410+ for targets in zip (* test_targets , strict = False )
411+ )
412+ predictions = tuple (
413+ torch .cat (preds , dim = 0 ) for preds in zip (* test_preds , strict = False )
400414 )
401- predictions = tuple (torch .cat (preds , dim = 0 ) for preds in zip (* test_preds ))
402415 # identifier columns
403- ids = tuple (np .concatenate (x ) for x in zip (* test_ids ))
416+ ids = tuple (np .concatenate (x ) for x in zip (* test_ids , strict = False ))
404417 return targets , predictions , ids
405418
406419 @torch .no_grad ()
@@ -445,83 +458,6 @@ def __repr__(self) -> str:
445458 return f"{ cls_name } with { n_params :,} trainable params at { n_epochs :,} epochs"
446459
447460
class Normalizer:
    """Normalize a Tensor and restore it later."""

    def __init__(self) -> None:
        """Initialize Normalizer with mean 0 and std 1."""
        # Identity transform until fit() is called.
        self.mean = torch.tensor(0)
        self.std = torch.tensor(1)

    def fit(self, tensor: Tensor, dim: int = 0, keepdim: bool = False) -> None:
        """Compute the mean and standard deviation of the given tensor.

        Args:
            tensor (Tensor): Tensor to determine the mean and standard deviation over.
            dim (int, optional): Which dimension to take mean and standard deviation
                over. Defaults to 0.
            keepdim (bool, optional): Whether to keep the reduced dimension in Tensor.
                Defaults to False.
        """
        self.mean = tensor.mean(dim=dim, keepdim=keepdim)
        self.std = tensor.std(dim=dim, keepdim=keepdim)

    def norm(self, tensor: Tensor) -> Tensor:
        """Normalize a Tensor.

        Args:
            tensor (Tensor): Tensor to be normalized

        Returns:
            Tensor: Normalized Tensor
        """
        centered = tensor - self.mean
        return centered / self.std

    def denorm(self, normed_tensor: Tensor) -> Tensor:
        """Restore normalized Tensor to original.

        Args:
            normed_tensor (Tensor): Tensor to be restored

        Returns:
            Tensor: Restored Tensor
        """
        return self.std * normed_tensor + self.mean

    def state_dict(self) -> dict[str, Tensor]:
        """Get Normalizer parameters mean and std.

        Returns:
            dict[str, Tensor]: Dictionary storing Normalizer parameters.
        """
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
        """Overwrite Normalizer parameters given a new state_dict.

        Args:
            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
        """
        # Parameters are always kept on CPU.
        self.mean, self.std = (state_dict[key].cpu() for key in ("mean", "std"))

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Normalizer:
        """Create a new Normalizer given a state_dict.

        Args:
            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.

        Returns:
            Normalizer
        """
        instance = cls()
        # Delegate to load_state_dict, which performs the same .cpu() moves.
        instance.load_state_dict(state_dict)
        return instance
524-
525461def save_checkpoint (
526462 state : dict [str , Any ], is_best : bool , model_name : str , run_id : int
527463) -> None :
0 commit comments