55from matplotlib import gridspec
66from .helpers import reconstruct , shift_factors , compute_loadings_percent_power , get_shapes
77
8- def seqnmf (X , K = 10 , L = 100 , Lambda = .001 , W_init = None , H_init = None , \
9- plot_it = True , max_iter = 100 , tol = - np .inf , shift = True , sort_factors = True , \
10- lambda_L1W = 0 , lambda_L1H = 0 , lambda_OrthH = 0 , lambda_OrthW = 0 , M = None , \
11- use_W_update = True , W_fixed = False ):
128
9+ def seqnmf (X , K = 10 , L = 100 , Lambda = .001 , W_init = None , H_init = None ,
10+ plot_it = False , max_iter = 100 , tol = - np .inf , shift = True , sort_factors = True ,
11+ lambda_L1W = 0 , lambda_L1H = 0 , lambda_OrthH = 0 , lambda_OrthW = 0 , M = None ,
12+ use_W_update = True , W_fixed = False ):
13+ '''
14+ :param X: an N (features) by T (timepoints) data matrix to be factorized using seqNMF
15+ :param K: the (maximum) number of factors to search for; any unused factors will be set to all zeros
16+ :param L: the (maximum) number of timepoints to consider in each factor; any unused timepoints will be set to zeros
17+ :param Lambda: regularization parameter (default: 0.001)
18+ :param W_init: initial factors (if unspecified, use random initialization)
19+ :param H_init: initial per-timepoint factor loadings (if unspecified, initialize randomly)
20+ :param plot_it: if True, display progress in each update using a plot (default: False)
21+ :param max_iter: maximum number of iterations/updates
22+ :param tol: if cost is within tol of the average of the previous 5 updates, the algorithm will terminate (default: tol = -inf)
23+ :param shift: allow timepoint shifts in H
24+ :param sort_factors: sort factors by time
25+ :param lambda_L1W: regularization parameter for W (default: 0)
26+ :param lambda_L1H: regularization parameter for H (default: 0)
27+ :param lambda_OrthH: regularization parameter for H (default: 0)
28+ :param lambda_OrthW: regularization parameter for W (default: 0)
29+ :param M: binary mask of the same size as X, used to ignore a subset of the data during training (default: use all data)
30+ :param use_W_update: set to True for more accurate results; set to False for faster results (default: True)
31+ :param W_fixed: if true, fix factors (W), e.g. for cross validation (default: False)
32+
33+ :return:
34+ :W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
35+ :H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
36+ :cost: a vector of length (number-of-iterations + 1) containing the initial cost and cost after each update (i.e. the reconstruction error)
37+ :loadings: the per-factor loadings-- i.e. the explanatory power of each individual factor
38+ :power: the total power (across all factors) explained by the full reconstruction
39+ '''
1340 N = X .shape [0 ]
14- T = X .shape [1 ] + 2 * L
41+ T = X .shape [1 ] + 2 * L
1542 X = np .concatenate ((np .zeros ([N , L ]), X , np .zeros ([N , L ])), axis = 1 )
1643
1744 if W_init is None :
@@ -38,8 +65,8 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
3865 cost [0 ] = np .sqrt (np .mean (np .power (X - X_hat , 2 )))
3966
4067 for i in np .arange (max_iter ):
41- if (i == max_iter - 1 ) or ((i > 6 ) and (cost [i + 1 ] + tol ) > np .mean (cost [i - 6 :i ])):
42- cost = cost [:(i + 2 )]
68+ if (i == max_iter - 1 ) or ((i > 6 ) and (cost [i + 1 ] + tol ) > np .mean (cost [i - 6 :i ])):
69+ cost = cost [:(i + 2 )]
4370 last_time = True
4471 if i > 0 :
4572 Lambda = 0
@@ -107,7 +134,7 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
107134 X_hat = reconstruct (W , H )
108135 mask = M == 0
109136 X [mask ] = X_hat [mask ]
110- cost [i + 1 ] = np .sqrt (np .mean (np .power (X - X_hat , 2 )))
137+ cost [i + 1 ] = np .sqrt (np .mean (np .power (X - X_hat , 2 )))
111138
112139 if plot_it :
113140 if i > 0 :
@@ -139,7 +166,15 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
139166
140167 return W , H , cost , loadings , power
141168
169+
142170def plot (W , H , cmap = 'gray_r' , factor_cmap = 'Spectral' ):
171+ '''
172+ :param W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
173+ :param H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
174+ :param cmap: colormap used to draw heatmaps for the factors, factor loadings, and data reconstruction
175+ :param factor_cmap: colormap used to distinguish individual factors
176+ :return f: matplotlib figure handle
177+ '''
143178 N , K , L , T = get_shapes (W , H )
144179
145180 data_recon = reconstruct (W , H )
@@ -150,20 +185,19 @@ def plot(W, H, cmap='gray_r', factor_cmap='Spectral'):
150185 ax_w = plt .subplot (gs [2 ])
151186 ax_data = plt .subplot (gs [3 ])
152187
153- #plot W, H, and data_recon
188+ # plot W, H, and data_recon
154189 sns .heatmap (np .hstack (list (map (np .squeeze , np .split (W , K , axis = 1 )))), cmap = cmap , ax = ax_w , cbar = False )
155190 sns .heatmap (H , cmap = cmap , ax = ax_h , cbar = False )
156191 sns .heatmap (data_recon , cmap = cmap , ax = ax_data , cbar = False )
157192
158- #add dividing bars for factors of W and H
193+ # add dividing bars for factors of W and H
159194 factor_colors = sns .color_palette (factor_cmap , K )
160195 for k in np .arange (K ):
161196 plt .sca (ax_w )
162- start_w = k * L
163- plt .plot ([start_w , start_w ], [0 , N - 1 ], '-' , color = factor_colors [k ])
197+ start_w = k * L
198+ plt .plot ([start_w , start_w ], [0 , N - 1 ], '-' , color = factor_colors [k ])
164199
165200 plt .sca (ax_h )
166- plt .plot ([0 , T - 1 ], [k , k ], '-' , color = factor_colors [k ])
201+ plt .plot ([0 , T - 1 ], [k , k ], '-' , color = factor_colors [k ])
167202
168203 return fig
169-
0 commit comments