1+ """
2+ In this example, a regression model with the ability to predict a mean and
3+ standard deviation is created and trained using torchfitter.
4+
By predicting a mean and a standard deviation, one can define a kind of
uncertainty interval around the predictions (i.e. how sure is the model
about the prediction for a given sample?).
8+ """
9+
10+ import torch
11+ import argparse
12+ import torch .nn as nn
13+ import torch .optim as optim
14+ import matplotlib .pyplot as plt
15+ from torchfitter .conventions import ParamsDict
16+ from sklearn .datasets import make_regression
17+ from torchfitter .utils .preprocessing import train_test_val_split , torch_to_numpy
18+ from torchfitter .trainer import Trainer
19+ from torch .utils .data import DataLoader
20+ from torchfitter .utils .data import DataWrapper
21+ from torchfitter .callbacks import RichProgressBar , EarlyStopping
22+
23+
class DeepNormal(nn.Module):
    """Neural network whose output parametrizes a normal distribution.

    Taken from [1].

    Parameters
    ----------
    n_inputs : int
        Number of input features.
    n_hidden : int
        Width of the hidden layers.

    References
    ----------
    .. [1] Romain Strock - Modeling uncertainty with Pytorch:
       https://romainstrock.com/blog/modeling-uncertainty-with-pytorch.html
    """

    def __init__(self, n_inputs, n_hidden):
        super().__init__()

        # Backbone shared by both distribution heads.
        self.shared_layer = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.ReLU(),
            nn.Dropout(),
        )

        # Head producing the distribution mean.
        self.mean_layer = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(n_hidden, 1),
        )

        # Head producing the standard deviation; Softplus enforces positivity.
        self.std_layer = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(n_hidden, 1),
            nn.Softplus(),
        )

    def forward(self, x):
        """Map a batch ``x`` to a ``torch.distributions.Normal`` instance."""
        embedding = self.shared_layer(x)
        return torch.distributions.Normal(
            self.mean_layer(embedding), self.std_layer(embedding)
        )
72+
73+
class NLLLoss(nn.Module):
    """Negative log-likelihood loss for distribution-valued predictions."""

    def forward(self, output, target):
        """Return the mean negative log-likelihood of ``target``.

        ``output`` is assumed to be a ``torch.distributions.Distribution``
        exposing ``log_prob``.
        """
        return -output.log_prob(target).mean()
84+
85+
def main():
    """Train a mean/std regression model and plot its predictions.

    Steps: parse CLI arguments, build a synthetic regression dataset,
    train a ``DeepNormal`` model with a negative log-likelihood
    criterion, then plot the losses, the mean predictions and a ~95%
    uncertainty band (mean +/- 2 std).
    """
    # -------------------------------------------------------------------------
    # argument parsing
    # BUG FIX: the positional argument of ArgumentParser is `prog`, not a
    # description; passing "" blanked the program name in --help output.
    parser = argparse.ArgumentParser(
        description="Regression example with uncertainty estimation."
    )
    parser.add_argument("--epochs", type=int, default=5000)

    args = parser.parse_args()
    n_epochs = args.epochs

    # -------------------------------------------------------------------------
    # generate dummy data
    X, y = make_regression(
        n_samples=5000, n_features=1, n_informative=1, noise=5, random_state=0
    )
    y = y.reshape(-1, 1)

    # split data into train, test and validation
    _tup = train_test_val_split(X, y)
    X_train, y_train, X_val, y_val, X_test, y_test = _tup

    # wrap data in Dataset
    train_wrapper = DataWrapper(
        X_train, y_train, dtype_X="float", dtype_y="float"
    )
    val_wrapper = DataWrapper(X_val, y_val, dtype_X="float", dtype_y="float")

    # torch Loaders
    train_loader = DataLoader(train_wrapper, batch_size=64, pin_memory=True)
    val_loader = DataLoader(val_wrapper, batch_size=64, pin_memory=True)

    # -------------------------------------------------------------------------
    # define model, optimizer and loss
    criterion = NLLLoss()
    model = DeepNormal(n_inputs=X.shape[1], n_hidden=15)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)

    # callbacks list
    callbacks = [
        EarlyStopping(patience=150, load_best=True),
        RichProgressBar(display_step=50),
    ]

    # instantiate Trainer object with all the configuration
    trainer = Trainer(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        callbacks=callbacks,
    )

    # train process
    history = trainer.fit(train_loader, val_loader, epochs=n_epochs)

    # -------------------------------------------------------------------------
    # the prediction is a torch distribution (see DeepNormal.forward)
    distr_prediction = trainer.predict(X_test)

    # get mean and standard deviation for each sample in test, as arrays
    y_pred = torch_to_numpy(distr_prediction.mean)
    y_pred_std = torch_to_numpy(distr_prediction.stddev)

    # -------------------------------------------------------------------------
    # plot losses, mean predictions and lr
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(19, 4))
    epoch_hist = history[ParamsDict.EPOCH_HISTORY]

    ax[0].plot(epoch_hist[ParamsDict.LOSS]["train"], label="Train loss")
    ax[0].plot(
        epoch_hist[ParamsDict.LOSS]["validation"], label="Validation loss"
    )
    ax[0].set_title("Train and validation losses")
    ax[0].grid()
    ax[0].legend()

    ax[1].plot(X_test, y_test, ".", label="Real")
    ax[1].plot(X_test, y_pred, ".", label="Prediction")
    ax[1].set_title("Predictions")
    ax[1].grid()
    ax[1].legend()

    ax[2].plot(epoch_hist[ParamsDict.HISTORY_LR], label="Learning rate")
    ax[2].set_title("Learning Rate")
    ax[2].legend()
    ax[2].grid()
    plt.show()

    # -------------------------------------------------------------------------
    # ~95% interval assuming normality: mean +/- 2 standard deviations
    lower = y_pred - 2 * y_pred_std
    upper = y_pred + 2 * y_pred_std

    fig, ax = plt.subplots(1, 1, figsize=(15, 8))

    # BUG FIX: the series were unlabeled, so ax.legend() below described
    # only "predicted means".
    ax.plot(X_test, y_test, "*k", label="Real")
    ax.scatter(X_test.flatten(), y_pred, label="predicted means")

    ax.scatter(X_test.flatten(), lower, label="lower bound (mean - 2 std)")
    ax.scatter(X_test.flatten(), upper, label="upper bound (mean + 2 std)")

    ax.grid(True)
    ax.legend()
    # BUG FIX: without this call the second figure was never displayed --
    # the first plt.show() had already returned and the script just exited.
    plt.show()
191+
192+
# Script entry point: run the example only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()