Skip to content

Commit 5ce927d

Browse files
author
Jeremy Manning
committed
updated doc strings
1 parent 0bfa9da commit 5ce927d

1 file changed

Lines changed: 48 additions & 14 deletions

File tree

seqnmf/seqnmf.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,40 @@
55
from matplotlib import gridspec
66
from .helpers import reconstruct, shift_factors, compute_loadings_percent_power, get_shapes
77

8-
def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
9-
plot_it=True, max_iter=100, tol=-np.inf, shift=True, sort_factors=True, \
10-
lambda_L1W=0, lambda_L1H=0, lambda_OrthH=0, lambda_OrthW=0, M=None, \
11-
use_W_update=True, W_fixed=False):
128

9+
def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None,
10+
plot_it=False, max_iter=100, tol=-np.inf, shift=True, sort_factors=True,
11+
lambda_L1W=0, lambda_L1H=0, lambda_OrthH=0, lambda_OrthW=0, M=None,
12+
use_W_update=True, W_fixed=False):
13+
'''
14+
:param X: an N (features) by T (timepoints) data matrix to be factorized using seqNMF
15+
:param K: the (maximum) number of factors to search for; any unused factors will be set to all zeros
16+
:param L: the (maximum) number of timepoints to consider in each factor; any unused timepoints will be set to zeros
17+
:param Lambda: regularization parameter (default: 0.001)
18+
:param W_init: initial factors (if unspecified, use random initialization)
19+
:param H_init: initial per-timepoint factor loadings (if unspecified, initialize randomly)
20+
:param plot_it: if True, display progress in each update using a plot (default: False)
21+
:param max_iter: maximum number of iterations/updates
22+
:param tol: if cost is within tol of the average of the previous 5 updates, the algorithm will terminate (default: tol = -inf)
23+
:param shift: allow timepoint shifts in H
24+
:param sort_factors: sort factors by time
25+
:param lambda_L1W: regularization parameter for W (default: 0)
26+
:param lambda_L1H: regularization parameter for H (default: 0)
27+
:param lambda_OrthH: regularization parameter for H (default: 0)
28+
:param lambda_OrthW: regularization parameter for W (default: 0)
29+
:param M: binary mask of the same size as X, used to ignore a subset of the data during training (default: use all data)
30+
:param use_W_update: set to True for more accurate results; set to False for faster results (default: True)
31+
:param W_fixed: if true, fix factors (W), e.g. for cross validation (default: False)
32+
33+
:return:
34+
:W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
35+
:H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
36+
:cost: a vector of length (number-of-iterations + 1) containing the initial cost and cost after each update (i.e. the reconstruction error)
37+
:loadings: the per-factor loadings-- i.e. the explanatory power of each individual factor
38+
:power: the total power (across all factors) explained by the full reconstruction
39+
'''
1340
N = X.shape[0]
14-
T = X.shape[1] + 2*L
41+
T = X.shape[1] + 2 * L
1542
X = np.concatenate((np.zeros([N, L]), X, np.zeros([N, L])), axis=1)
1643

1744
if W_init is None:
@@ -38,8 +65,8 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
3865
cost[0] = np.sqrt(np.mean(np.power(X - X_hat, 2)))
3966

4067
for i in np.arange(max_iter):
41-
if (i == max_iter-1) or ((i > 6) and (cost[i + 1] + tol) > np.mean(cost[i - 6:i])):
42-
cost = cost[:(i+2)]
68+
if (i == max_iter - 1) or ((i > 6) and (cost[i + 1] + tol) > np.mean(cost[i - 6:i])):
69+
cost = cost[:(i + 2)]
4370
last_time = True
4471
if i > 0:
4572
Lambda = 0
@@ -107,7 +134,7 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
107134
X_hat = reconstruct(W, H)
108135
mask = M == 0
109136
X[mask] = X_hat[mask]
110-
cost[i+1] = np.sqrt(np.mean(np.power(X - X_hat, 2)))
137+
cost[i + 1] = np.sqrt(np.mean(np.power(X - X_hat, 2)))
111138

112139
if plot_it:
113140
if i > 0:
@@ -139,7 +166,15 @@ def seqnmf(X, K=10, L=100, Lambda=.001, W_init=None, H_init=None, \
139166

140167
return W, H, cost, loadings, power
141168

169+
142170
def plot(W, H, cmap='gray_r', factor_cmap='Spectral'):
171+
'''
172+
:param W: N (features) by K (factors) by L (per-factor timepoints) tensor of factors
173+
:param H: K (factors) by T (timepoints) matrix of factor loadings (i.e. factor timecourses)
174+
:param cmap: colormap used to draw heatmaps for the factors, factor loadings, and data reconstruction
175+
:param factor_cmap: colormap used to distinguish individual factors
176+
:return f: matplotlib figure handle
177+
'''
143178
N, K, L, T = get_shapes(W, H)
144179

145180
data_recon = reconstruct(W, H)
@@ -150,20 +185,19 @@ def plot(W, H, cmap='gray_r', factor_cmap='Spectral'):
150185
ax_w = plt.subplot(gs[2])
151186
ax_data = plt.subplot(gs[3])
152187

153-
#plot W, H, and data_recon
188+
# plot W, H, and data_recon
154189
sns.heatmap(np.hstack(list(map(np.squeeze, np.split(W, K, axis=1)))), cmap=cmap, ax=ax_w, cbar=False)
155190
sns.heatmap(H, cmap=cmap, ax=ax_h, cbar=False)
156191
sns.heatmap(data_recon, cmap=cmap, ax=ax_data, cbar=False)
157192

158-
#add dividing bars for factors of W and H
193+
# add dividing bars for factors of W and H
159194
factor_colors = sns.color_palette(factor_cmap, K)
160195
for k in np.arange(K):
161196
plt.sca(ax_w)
162-
start_w = k*L
163-
plt.plot([start_w, start_w], [0, N-1], '-', color=factor_colors[k])
197+
start_w = k * L
198+
plt.plot([start_w, start_w], [0, N - 1], '-', color=factor_colors[k])
164199

165200
plt.sca(ax_h)
166-
plt.plot([0, T-1], [k, k], '-', color=factor_colors[k])
201+
plt.plot([0, T - 1], [k, k], '-', color=factor_colors[k])
167202

168203
return fig
169-

0 commit comments

Comments
 (0)