 import torch.nn.functional as F
 
 class CensoredMixtureLogistic(nn.Module):
-    def __init__(self, n, res=1e-2, lo='-inf', hi='inf', sharp_bounds=(1e-4,2e3)):
+    def __init__(self, n, res=1e-2, lo='-inf', hi='inf',
+            sharp_bounds=(1e-4,2e3), init=None):
         super().__init__()
         self.n = n
         self.res = res
         self.sharp_bounds = sharp_bounds
-        # self.base_bounds = tuple(reversed([math.exp(-s) for s in sharp_bounds]))
-        # self.register_buffer('max_sharp', torch.tensor(float(max_sharp)))
         self.register_buffer('lo', torch.tensor(float(lo)))
         self.register_buffer('hi', torch.tensor(float(hi)))
         # TODO: init is not general-purpose
         #TODO
-        self.bias = nn.Parameter(torch.cat((
-            torch.zeros(n), torch.logspace(-3,1,n), torch.zeros(n)
-        )))
+        if init=='time':
+            self.bias = nn.Parameter(torch.cat((
+                torch.zeros(n), torch.logspace(-3,1,n), torch.zeros(n)
+            )))
+        elif init=='velocity':
+            self.bias = nn.Parameter(torch.cat((
+                torch.zeros(n), torch.linspace(1,126,n), torch.zeros(n)
+            )))
+        else:
+            self.bias = nn.Parameter(torch.cat((
+                torch.zeros(n), torch.randn(n), torch.zeros(n)
+            )))
 
     @property
     def n_params(self):
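For orientation (this note is mine, not part of the commit): the head consumes `3*n` unconstrained values per output, and the learned `bias` uses the same layout, so the new `init` argument only changes where the location channel starts out. `'time'` spreads locations log-uniformly over 1e-3..10, while `'velocity'` spreads them linearly over 1..126, which looks like the MIDI velocity range. A minimal sketch of the layout, assuming `n_params` evaluates to `3 * n` as the `torch.chunk(h, 3, -1)` in `get_params` suggests:

```python
import torch

n = 4
bias = torch.cat((
    torch.zeros(n),             # mixture logits: start uniform
    torch.linspace(1, 126, n),  # locations: the 'velocity' init
    torch.zeros(n),             # raw sharpness: softplus(0) = log(2)
))
logit_pi, loc, log_s = torch.chunk(bias, 3, -1)
assert bias.shape[-1] == 3 * n
```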
@@ -28,7 +36,7 @@ def n_params(self):
     def get_params(self, h):
         assert h.shape[-1] == self.n_params
         h = h + self.bias
-        # get parameters fron unconstrained hidden state:
+        # get parameters from unconstrained hidden state:
         logit_pi, loc, log_s = torch.chunk(h, 3, -1)
         # mixture coefficients
         log_pi = logit_pi - logit_pi.logsumexp(-1,keepdim=True)
@@ -37,48 +45,30 @@ def get_params(self, h):
         # sharpness
         s = F.softplus(log_s).clamp(*self.sharp_bounds)
         return log_pi, loc, s
-        # exp of negative sharpness
-        # b = logit_b.sigmoid().clamp(*self.base_bounds)
-        # log_b = F.logsigmoid(logit_b).clamp(
-        #     math.log(self.base_bounds[0]), math.log(self.base_bounds[1]))
-        # b = ((logit_b / (logit_b.abs() + 1))*.5+.5).clamp(*self.base_bounds)
-        # log_b = b.log()
-        # return log_pi, loc, b, log_b
+
 
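Two small identities are worth noting here (my gloss, not the author's): subtracting the `logsumexp` is exactly a log-softmax, so `log_pi` holds normalized log mixture weights, and the clamped softplus makes `s` a strictly positive inverse scale. A standalone check on dummy values:

```python
import torch
import torch.nn.functional as F

h = torch.randn(2, 12)  # batch of 2, n = 4, so 3*n = 12
logit_pi, loc, log_s = torch.chunk(h, 3, -1)
log_pi = logit_pi - logit_pi.logsumexp(-1, keepdim=True)
assert torch.allclose(log_pi, F.log_softmax(logit_pi, -1), atol=1e-6)
s = F.softplus(log_s).clamp(1e-4, 2e3)  # sharpness: always positive
assert (s > 0).all()
```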
     def forward(self, h, x):
         """log prob of x under distribution parameterized by h"""
-        # log_pi, loc, b, log_b = self.get_params(h)
         log_pi, loc, s = self.get_params(h)
 
         d = self.res / 2
-        x = x.clamp(self.lo - d, self.hi + d)[...,None]
+        x = x.clamp(self.lo, self.hi)[...,None]
         x_ = (x - loc) * s
-
-        # numerical crimes follow
-        # q = b ** -x_
-        q = x_.exp()
         sd = s * d
-        # bdp, bdm = b**d, b**-d
-        sdm, sdp = (-sd).exp(), sd.exp()
-        # # censoring
-        lo_cens = x <= self.lo
-        hi_cens = x >= self.hi
-        ones = torch.ones_like(q)
-        zeros = torch.zeros_like(q)
-
-        diff_term = torch.where(
-            lo_cens | hi_cens, ones, sdp - sdm).log()
-        # sdm_term = torch.where(
-        #     hi_cens, ones, (q + sdm)).log()
-        sdm_term = torch.where(hi_cens, zeros, x_ + F.softplus(-sd - x_))
-        # sdp_term = torch.where(
-        #     lo_cens, ones, (q + sdp)).log()
-        sdp_term = torch.where(lo_cens, zeros, x_ + F.softplus(sd - x_))
-        x_or_sd = torch.where(hi_cens, sd, x_)
 
-        log_delta_cdf = (
-            x_or_sd + diff_term - sdm_term - sdp_term
-        )
+        # # censoring
+        lo_cens = x <= self.lo + d
+        hi_cens = x >= self.hi - d
+        ones = torch.ones_like(x_)
+        zeros = torch.zeros_like(x_)
+
+        diff_term = torch.where(lo_cens | hi_cens,
+            ones, sd.exp() - (-sd).exp()
+        ).log()
+        minus_sp_term = torch.where(hi_cens, -sd, F.softplus(-sd - x_))
+        plus_sp_term = torch.where(lo_cens, zeros, x_ + F.softplus(sd - x_))
+
+        log_delta_cdf = diff_term - minus_sp_term - plus_sp_term
 
         # log prob
         r = {
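For intuition about the block above (my reading of the code, not text from the commit): each mixture component is a logistic distribution with location `loc` and scale `1/s`; a target receives the probability mass of a bin of width `res` centered on it, and targets within half a bin of `lo` or `hi` absorb the entire censored tail. The `softplus` form computed above is an algebraically equal but overflow-safe version of this direct computation:

```python
import torch

def naive_log_delta_cdf(x, loc, s, d, lo, hi):
    # Unstable reference: probability mass of a logistic component
    # in the bin [x - d, x + d], with censoring at the boundaries.
    x_ = (x - loc) * s
    sd = s * d
    p = torch.sigmoid(x_ + sd) - torch.sigmoid(x_ - sd)         # interior bin
    p = torch.where(x <= lo + d, torch.sigmoid(x_ + sd), p)     # all mass below lo + d
    p = torch.where(x >= hi - d, torch.sigmoid(-(x_ - sd)), p)  # all mass above hi - d
    return p.log()
```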
@@ -87,23 +77,23 @@ def forward(self, h, x):
         # diagnostics
         with torch.no_grad():
             ent = D.Categorical(logits=log_pi).entropy()
-            # s = -b.log()
             r |= {
-                'min_sharpness': s.min(),
+                # 'min_sharpness': s.min(),
                 'max_sharpness': s.max(),
-                'min_entropy': ent.min(),
-                'max_entropy': ent.max(),
-                'marginal_entropy': D.Categorical(
+                'mean_sharpness': (s*log_pi.exp()).sum(-1).mean(),
+                # 'min_entropy': ent.min(),
+                # 'max_entropy': ent.max(),
+                'mean_cmp_entropy': ent.mean(),
+                'marginal_cmp_entropy': D.Categorical(
                     log_pi.exp().mean(list(range(log_pi.ndim-1)))).entropy(),
-                'min_loc': loc.min(),
-                'max_loc': loc.max()
+                # 'min_loc': loc.min(),
+                # 'max_loc': loc.max()
             }
         return r
 
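A note on the two entropy diagnostics introduced above (my explanation): `mean_cmp_entropy` averages the per-example entropy of the mixture weights, while `marginal_cmp_entropy` is the entropy of the weights averaged over the batch. A low mean with a high marginal means individual predictions commit to a component while different inputs use different components. A tiny illustration:

```python
import torch
import torch.distributions as D

log_pi = torch.log_softmax(torch.tensor([[10., 0.], [0., 10.]]), -1)
mean_cmp = D.Categorical(logits=log_pi).entropy().mean()        # ~0 nats
marginal = D.Categorical(probs=log_pi.exp().mean(0)).entropy()  # ~log(2) nats
```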
     def cdf(self, h, x):
         log_pi, loc, s = self.get_params(h)
         x_ = (x[...,None] - loc) * s
-        # cdfs = 1 / (1 + b ** x_)
         cdfs = x_.sigmoid()
         cdf = (cdfs * log_pi.softmax(-1)).sum(-1)
         return cdf
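One small observation about `cdf` (mine): since `log_pi` comes back from `get_params` already normalized, `log_pi.softmax(-1)` computes the same weights as `log_pi.exp()`; the softmax is a harmless re-normalization. For example:

```python
import torch

log_pi = torch.tensor([0.3, 0.7]).log()  # already-normalized log weights
assert torch.allclose(log_pi.softmax(-1), log_pi.exp())
```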
@@ -113,9 +103,7 @@ def sample(self, h, shape=1):
         Args:
             shape: additional sample shape to be prepended to dims
         """
-        # log_pi, loc, _, log_b = self.get_params(h)
         log_pi, loc, s = self.get_params(h)
-        # scale = -1/log_b
         scale = 1/s
 
         c = D.Categorical(logits=log_pi).sample((shape,))
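To close, a hypothetical end-to-end use of the module (every name and number here is mine for illustration; `net`, the feature size, and the velocity bounds are assumptions, not the repo's code):

```python
import torch
import torch.nn as nn

n = 8
head = CensoredMixtureLogistic(n, res=1.0, lo=1, hi=127, init='velocity')
net = nn.Linear(32, head.n_params)          # stand-in feature-to-parameter map

feats = torch.randn(16, 32)                 # batch of 16 feature vectors
x = torch.randint(1, 128, (16,)).float()    # targets, e.g. MIDI velocities
r = head(net(feats), x)                     # log prob plus diagnostics
samples = head.sample(net(feats), shape=4)  # 4 draws per batch element
```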