Skip to content

Commit 23ac9e0

Browse files
committed
Fix docs
1 parent d722c5c commit 23ac9e0

9 files changed

Lines changed: 69 additions & 101 deletions

File tree

docs/source/_templates/autosummary/module.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
{{ fullname | escape | underline}}
22

33
.. automodule:: {{ fullname }}
4+
:no-members:
5+
:no-undoc-members:
6+
:no-special-members:
47

58
{% block attributes %}
69
{% if attributes %}

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
templates_path = ["_templates"]
3838
source_suffix = [".rst"]
3939
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
40+
suppress_warnings = ["myst.xref_missing"]
4041

4142
# The toctree master document
4243
master_doc = "index"

src/ezmsg/learn/dim_reduce/adaptive_decomp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,11 @@ class MiniBatchNMFSettings(AdaptiveDecompSettings):
227227

228228
beta_loss: typing.Union[str, float] = "frobenius"
229229
"""
230-
'frobenius', 'kullback-leibler', 'itakura-saito'
230+
'frobenius', 'kullback-leibler', 'itakura-saito'.
231231
Note that values different from 'frobenius'
232-
(or 2) and 'kullback-leibler' (or 1) lead to significantly slower
233-
fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
234-
matrix `X` cannot contain zeros.
232+
(or 2) and 'kullback-leibler' (or 1) lead to significantly slower
233+
fits. Note that for ``beta_loss <= 0`` (or 'itakura-saito'), the input
234+
matrix ``X`` cannot contain zeros.
235235
"""
236236

237237
tol: float = 1e-4

src/ezmsg/learn/linear_model/sgd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
SGDDecoderSettings as SGDDecoderSettings,
66
)
77
from ..process.sgd import (
8-
sgd_decoder as sgd_decoder,
8+
SGDDecoderTransformer as SGDDecoderTransformer,
99
)

src/ezmsg/learn/model/mlp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
):
2626
"""
2727
Initialize the MLP model.
28+
2829
Args:
2930
input_size (int): The size of the input features.
3031
hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the

src/ezmsg/learn/model/rnn.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ def __init__(
3737
rnn_klass_str = rnn_type.upper().split("-")[0]
3838
if rnn_klass_str not in ["GRU", "LSTM", "RNN"]:
3939
raise ValueError(f"Unrecognized rnn_type: {rnn_type}")
40-
rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[
41-
rnn_klass_str
42-
]
40+
rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[rnn_klass_str]
4341
rnn_kwargs = {}
4442
if rnn_klass_str == "RNN":
4543
rnn_kwargs["nonlinearity"] = rnn_type.lower().split("-")[-1]
@@ -57,16 +55,11 @@ def __init__(
5755
if isinstance(output_size, int):
5856
output_size = {"output": output_size}
5957
self.heads = torch.nn.ModuleDict(
60-
{
61-
name: torch.nn.Linear(hidden_size, size)
62-
for name, size in output_size.items()
63-
}
58+
{name: torch.nn.Linear(hidden_size, size) for name, size in output_size.items()}
6459
)
6560

6661
@classmethod
67-
def infer_config_from_state_dict(
68-
cls, state_dict: dict, rnn_type: str = "GRU"
69-
) -> dict[str, int | float]:
62+
def infer_config_from_state_dict(cls, state_dict: dict, rnn_type: str = "GRU") -> dict[str, int | float]:
7063
"""
7164
This method is specific to each processor.
7265
@@ -88,8 +81,7 @@ def infer_config_from_state_dict(
8881
# Infer input_size from linear_embeddings.weight (shape: [input_size, input_size])
8982
"input_size": state_dict["linear_embeddings.weight"].shape[1],
9083
# Infer hidden_size from rnn.weight_ih_l0 (shape: [hidden_size * 3, input_size])
91-
"hidden_size": state_dict["rnn.weight_ih_l0"].shape[0]
92-
// cls._get_gate_count(rnn_type),
84+
"hidden_size": state_dict["rnn.weight_ih_l0"].shape[0] // cls._get_gate_count(rnn_type),
9385
# Infer num_layers by counting rnn layers in state_dict (e.g., weight_ih_l<k>)
9486
"num_layers": sum(1 for key in state_dict if "rnn.weight_ih_l" in key),
9587
"output_size": output_size,
@@ -134,27 +126,25 @@ def forward(
134126
) -> tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
135127
"""
136128
Forward pass through the RNN model.
129+
137130
Args:
138131
x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
139132
input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each sequence in the batch.
140133
If provided, sequences will be packed before passing through the RNN.
141134
hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional initial hidden state for the RNN.
135+
142136
Returns:
143137
tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
144-
A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size).
145-
If the RNN is LSTM, the second element is the hidden state (h_n, c_n) or just h_n if GRU.
138+
A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size).
139+
If the RNN is LSTM, the second element is the hidden state (h_n, c_n) or just h_n if GRU.
146140
"""
147141
x = self.linear_embeddings(x)
148142
x = self.dropout_input(x)
149143
total_length = x.shape[1]
150144
if input_lens is not None:
151-
x = torch.nn.utils.rnn.pack_padded_sequence(
152-
x, input_lens, batch_first=True, enforce_sorted=False
153-
)
145+
x = torch.nn.utils.rnn.pack_padded_sequence(x, input_lens, batch_first=True, enforce_sorted=False)
154146
x_out, hx_out = self.rnn(x, hx)
155147
if input_lens is not None:
156-
x_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
157-
x_out, batch_first=True, total_length=total_length
158-
)
148+
x_out, _ = torch.nn.utils.rnn.pad_packed_sequence(x_out, batch_first=True, total_length=total_length)
159149
x_out = self.output_dropout(x_out)
160150
return {name: head(x_out) for name, head in self.heads.items()}, hx_out

src/ezmsg/learn/model/transformer.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ def __init__(
4949
else:
5050
autoregressive_size = list(output_size.values())[0]
5151
if isinstance(output_size, dict):
52-
autoregressive_size = output_size.get(
53-
autoregressive_head, autoregressive_size
54-
)
52+
autoregressive_size = output_size.get(autoregressive_head, autoregressive_size)
5553
self.start_token = torch.nn.Parameter(torch.zeros(1, 1, autoregressive_size))
5654
self.output_to_hidden = torch.nn.Linear(autoregressive_size, hidden_size)
5755

@@ -86,10 +84,7 @@ def __init__(
8684
if isinstance(output_size, int):
8785
output_size = {"output": output_size}
8886
self.heads = torch.nn.ModuleDict(
89-
{
90-
name: torch.nn.Linear(hidden_size, out_dim)
91-
for name, out_dim in output_size.items()
92-
}
87+
{name: torch.nn.Linear(hidden_size, out_dim) for name, out_dim in output_size.items()}
9388
)
9489

9590
@classmethod
@@ -108,13 +103,9 @@ def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float
108103
"hidden_size": state_dict["input_proj.weight"].shape[0],
109104
"output_size": output_size,
110105
# Infer encoder_layers from transformer layers in state_dict
111-
"encoder_layers": len(
112-
[k for k in state_dict if k.startswith("encoder.layers")]
113-
),
106+
"encoder_layers": len([k for k in state_dict if k.startswith("encoder.layers")]),
114107
# Infer decoder_layers from transformer decoder layers in state_dict
115-
"decoder_layers": len(
116-
{k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")}
117-
)
108+
"decoder_layers": len({k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")})
118109
if any(k.startswith("decoder.layers") for k in state_dict)
119110
else 0,
120111
}
@@ -129,20 +120,22 @@ def forward(
129120
) -> dict[str, torch.Tensor]:
130121
"""
131122
Forward pass through the transformer model.
123+
132124
Args:
133125
src (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
134126
tgt (Optional[torch.Tensor]): Target tensor for decoder, shape (batch, seq_len, input_size).
135-
Required if `decoder_layers > 0`. In training, this can be the ground-truth target sequence
127+
Required if ``decoder_layers > 0``. In training, this can be the ground-truth target sequence
136128
(i.e. teacher forcing). During inference, this is constructed autoregressively.
137129
src_mask (Optional[torch.Tensor]): Optional attention mask for the encoder input. Should be broadcastable
138130
to shape (batch, seq_len, seq_len) or (seq_len, seq_len).
139131
tgt_mask (Optional[torch.Tensor]): Optional attention mask for the decoder input. Used to enforce causal
140132
decoding (i.e. autoregressive generation) during training or inference.
141133
start_pos (int): Starting offset for positional embeddings. Used for streaming inference to maintain
142134
correct positional indices. Default is 0.
135+
143136
Returns:
144-
dict[str, torch.Tensor]: Dictionary of output tensors each output head, each with shape (batch, seq_len,
145-
output_size).
137+
dict[str, torch.Tensor]: Dictionary of output tensors for each output head, each with shape
138+
(batch, seq_len, output_size).
146139
"""
147140
B, T, _ = src.shape
148141
device = src.device
@@ -158,9 +151,7 @@ def forward(
158151
if tgt is None:
159152
tgt = self.start_token.expand(B, -1, -1).to(device)
160153
tgt_proj = self.output_to_hidden(tgt)
161-
tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand(
162-
B, tgt.shape[1]
163-
)
154+
tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand(B, tgt.shape[1])
164155
tgt_proj = tgt_proj + self.pos_embedding(tgt_pos_ids)
165156
tgt_proj = self.dropout(tgt_proj)
166157
out = self.decoder(

src/ezmsg/learn/process/refit_kalman.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,15 @@ class RefitKalmanFilterSettings(ez.Settings):
3131
3232
This class defines the configuration parameters for the Refit Kalman filter processor.
3333
The RefitKalmanFilter is designed for online processing and playback.
34-
35-
Attributes:
36-
checkpoint_path: Path to saved model parameters (optional).
37-
If provided, loads pre-trained parameters instead of learning from data.
38-
steady_state: Whether to use steady-state Kalman filter.
39-
If True, uses pre-computed Kalman gain; if False, updates dynamically.
4034
"""
4135

4236
checkpoint_path: str | None = None
37+
"""Path to saved model parameters. If provided, loads pre-trained parameters instead of learning from data."""
38+
4339
steady_state: bool = False
40+
"""Whether to use steady-state Kalman filter. If True, uses pre-computed Kalman gain;
41+
if False, updates dynamically."""
42+
4443
velocity_indices: tuple[int, int] = (2, 3)
4544

4645

@@ -51,28 +50,31 @@ class RefitKalmanFilterState:
5150
5251
This class manages the persistent state of the Refit Kalman filter processor,
5352
including the model instance, current state estimates, and data buffers for refitting.
54-
55-
Attributes:
56-
model: The RefitKalmanFilter model instance.
57-
x: Current state estimate (n_states,).
58-
P: Current state covariance matrix (n_states x n_states).
59-
buffer_neural: Buffer for storing neural activity data for refitting.
60-
buffer_state: Buffer for storing state estimates for refitting.
61-
buffer_cursor_positions: Buffer for storing cursor positions for refitting.
62-
buffer_target_positions: Buffer for storing target positions for refitting.
63-
buffer_hold_flags: Buffer for storing hold flags for refitting.
64-
current_position: Current cursor position estimate (2,).
6553
"""
6654

6755
model: RefitKalmanFilter | None = None
56+
"""The RefitKalmanFilter model instance."""
57+
6858
x: object | None = None # Array API; namespace matches source data.
59+
"""Current state estimate (n_states,)."""
60+
6961
P: object | None = None # Array API; namespace matches source data.
62+
"""Current state covariance matrix (n_states x n_states)."""
7063

7164
buffer_neural: list | None = None
65+
"""Buffer for storing neural activity data for refitting."""
66+
7267
buffer_state: list | None = None
68+
"""Buffer for storing state estimates for refitting."""
69+
7370
buffer_cursor_positions: list | None = None
71+
"""Buffer for storing cursor positions for refitting."""
72+
7473
buffer_target_positions: list | None = None
74+
"""Buffer for storing target positions for refitting."""
75+
7576
buffer_hold_flags: list | None = None
77+
"""Buffer for storing hold flags for refitting."""
7678

7779

7880
class RefitKalmanFilterProcessor(
@@ -382,10 +384,8 @@ def refit_model(self):
382384
Refit the observation model (H, Q) using buffered measurements and contextual data.
383385
384386
This method updates the model's understanding of the neural-to-state mapping
385-
by calculating a new observation matrix and noise covariance, based on:
386-
- Logged neural data
387-
- Cursor state estimates
388-
- Hold flags and target positions
387+
by calculating a new observation matrix and noise covariance, based on
388+
logged neural data, cursor state estimates, hold flags, and target positions.
389389
390390
Args:
391391
velocity_indices (tuple): Indices in the state vector corresponding to velocity components.

src/ezmsg/learn/process/sklearn.py

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -55,33 +55,17 @@ class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisAr
5555
5656
The processor expects and outputs `AxisArray` messages with a `"ch"` (channel) axis.
5757
58-
Settings:
59-
---------
60-
model_class : str
61-
Full path to the sklearn or River model class to use.
62-
Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression"
63-
64-
model_kwargs : dict[str, typing.Any], optional
65-
Additional keyword arguments passed to the model constructor.
66-
67-
checkpoint_path : str, optional
68-
Path to a pickle file to load a previously saved model. If provided, the model will
69-
be restored from this path at startup.
70-
71-
partial_fit_classes : np.ndarray, optional
72-
For classifiers that require all class labels to be specified during `partial_fit`.
73-
74-
Example:
75-
-----------------------------
76-
```python
77-
processor = SklearnModelProcessor(
78-
settings=SklearnModelSettings(
79-
model_class='sklearn.linear_model.SGDClassifier',
80-
model_kwargs={'loss': 'log_loss'},
81-
partial_fit_classes=np.array([0, 1]),
58+
See :class:`SklearnModelSettings` for configuration options.
59+
60+
Example::
61+
62+
processor = SklearnModelProcessor(
63+
settings=SklearnModelSettings(
64+
model_class='sklearn.linear_model.SGDClassifier',
65+
model_kwargs={'loss': 'log_loss'},
66+
partial_fit_classes=np.array([0, 1]),
67+
)
8268
)
83-
)
84-
```
8569
"""
8670

8771
def _init_model(self) -> None:
@@ -224,17 +208,15 @@ class SklearnModelUnit(BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArr
224208
in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs predictions
225209
in the same format, optionally performing training via `partial_fit` or `fit`.
226210
227-
Example:
228-
--------
229-
```python
230-
unit = SklearnModelUnit(
231-
settings=SklearnModelSettings(
232-
model_class='sklearn.linear_model.SGDClassifier',
233-
model_kwargs={'loss': 'log_loss'},
234-
partial_fit_classes=np.array([0, 1]),
211+
Example::
212+
213+
unit = SklearnModelUnit(
214+
settings=SklearnModelSettings(
215+
model_class='sklearn.linear_model.SGDClassifier',
216+
model_kwargs={'loss': 'log_loss'},
217+
partial_fit_classes=np.array([0, 1]),
218+
)
235219
)
236-
)
237-
```
238220
"""
239221

240222
SETTINGS = SklearnModelSettings

0 commit comments

Comments
 (0)