Skip to content

Commit b04ee18

Browse files
authored
Merge pull request #1 from ezmsg-org/cboulay/migration
Bring in some basic ML nodes
2 parents c1374ad + cc2a023 commit b04ee18

26 files changed

Lines changed: 5148 additions & 9 deletions

.github/workflows/python-publish-ezmsg-learn.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- uses: actions/checkout@v4
1818

1919
- name: Install uv
20-
uses: astral-sh/setup-uv@v2
20+
uses: astral-sh/setup-uv@v6
2121

2222
- name: Build Package
2323
run: uv build

.github/workflows/python-tests.yml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ name: Test package
22

33
on:
44
push:
5-
branches: [main]
5+
branches:
6+
- main
7+
- dev
68
pull_request:
79
branches:
810
- main
11+
- dev
912
workflow_dispatch:
1013

1114
jobs:
@@ -23,16 +26,12 @@ jobs:
2326
- uses: actions/checkout@v4
2427

2528
- name: Install uv
26-
uses: astral-sh/setup-uv@v2
29+
uses: astral-sh/setup-uv@v6
2730
with:
28-
enable-cache: true
29-
cache-dependency-glob: "uv.lock"
30-
31-
- name: Set up Python ${{ matrix.python-version }}
32-
run: uv python install ${{ matrix.python-version }}
31+
python-version: ${{ matrix.python-version }}
3332

3433
- name: Install the project
35-
run: uv sync --all-extras
34+
run: uv sync
3635

3736
- name: Lint
3837
run:

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.11.12
4+
hooks:
5+
- id: ruff
6+
args: [ --fix ]
7+
- id: ruff-format

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ lint = [
2525
"ruff>=0.12.9",
2626
]
2727
test = [
28+
"hmmlearn>=0.3.3",
2829
"pytest>=8.4.1",
2930
]
3031

@@ -40,3 +41,9 @@ version-file = "src/ezmsg/learn/__version__.py"
4041

4142
[tool.hatch.build.targets.wheel]
4243
packages = ["src/ezmsg"]
44+
45+
[tool.pytest.ini_options]
46+
pythonpath = [
47+
"src",
48+
".",
49+
]

src/ezmsg/learn/model/mlp.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import torch
2+
import torch.nn
3+
4+
5+
class MLP(torch.nn.Module):
6+
"""
7+
A simple Multi-Layer Perceptron (MLP) model. Adapted from Ezmsg MLP.
8+
9+
Attributes:
10+
feature_extractor (torch.nn.Sequential): The sequential feature extractor part of the MLP.
11+
heads (torch.nn.ModuleDict): A dictionary of output linear layers for each output head.
12+
"""
13+
14+
def __init__(
15+
self,
16+
input_size: int,
17+
hidden_size: int | list[int],
18+
num_layers: int | None = None,
19+
output_heads: int | dict[str, int] = 2,
20+
norm_layer: str | None = None,
21+
activation_layer: str | None = "ReLU",
22+
inplace: bool | None = None,
23+
bias: bool = True,
24+
dropout: float = 0.0,
25+
):
26+
"""
27+
Initialize the MLP model.
28+
Args:
29+
input_size (int): The size of the input features.
30+
hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the length
31+
of the list. If a single integer, num_layers must be specified and determines the number of hidden layers.
32+
num_layers (int, optional): The number of hidden layers. Length of hidden_size if None. Default is None.
33+
output_heads (int | dict[str, int], optional): Number of output features or classes if single head output or a
34+
dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
35+
norm_layer (str, optional): A normalization layer to be applied after each linear layer. Default is None.
36+
Common choices are "BatchNorm1d" or "LayerNorm".
37+
activation_layer (str, optional): An activation function to be applied after each normalization
38+
layer. Default is "ReLU".
39+
inplace (bool, optional): Whether the activation function is performed in-place. Default is None.
40+
bias (bool, optional): Whether to use bias in the linear layers. Default is True.
41+
dropout (float, optional): The dropout rate to be applied after each linear layer. Default is 0.0.
42+
"""
43+
super().__init__()
44+
if isinstance(hidden_size, int):
45+
if num_layers is None:
46+
raise ValueError(
47+
"If hidden_size is an integer, num_layers must be specified."
48+
)
49+
hidden_size = [hidden_size] * num_layers
50+
if len(hidden_size) == 0:
51+
raise ValueError("hidden_size must have at least one element")
52+
if any(not isinstance(x, int) for x in hidden_size):
53+
raise ValueError("hidden_size must contain only integers")
54+
if num_layers is not None and len(hidden_size) != num_layers:
55+
raise ValueError(
56+
"Length of hidden_size must match num_layers if num_layers is specified."
57+
)
58+
59+
params = {} if inplace is None else {"inplace": inplace}
60+
61+
layers = []
62+
in_dim = input_size
63+
64+
def _get_layer_class(layer_name: str):
65+
if layer_name is not None and "torch.nn" in layer_name:
66+
return getattr(torch.nn, layer_name.rsplit(".", 1)[1])
67+
return None
68+
69+
norm_layer_class = _get_layer_class(norm_layer)
70+
activation_layer_class = _get_layer_class(activation_layer)
71+
for hidden_dim in hidden_size[:-1]:
72+
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
73+
if norm_layer_class is not None:
74+
layers.append(norm_layer_class(hidden_dim))
75+
if activation_layer_class is not None:
76+
layers.append(activation_layer_class(**params))
77+
layers.append(torch.nn.Dropout(dropout, **params))
78+
in_dim = hidden_dim
79+
80+
layers.append(torch.nn.Linear(in_dim, hidden_size[-1], bias=bias))
81+
82+
self.feature_extractor = torch.nn.Sequential(*layers)
83+
84+
if isinstance(output_heads, int):
85+
output_heads = {"output": output_heads}
86+
self.heads = torch.nn.ModuleDict(
87+
{
88+
name: torch.nn.Linear(hidden_size[-1], output_size)
89+
for name, output_size in output_heads.items()
90+
}
91+
)
92+
93+
@classmethod
94+
def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float]:
95+
"""
96+
Infer the configuration from the state dict.
97+
98+
Args:
99+
state_dict: The state dict of the model.
100+
101+
Returns:
102+
dict[str, int | float]: A dictionary containing the inferred configuration.
103+
"""
104+
input_size = state_dict["feature_extractor.0.weight"].shape[1]
105+
hidden_size = [
106+
param.shape[0]
107+
for key, param in state_dict.items()
108+
if key.startswith("feature_extractor.") and key.endswith(".weight")
109+
]
110+
output_heads = {
111+
key.split(".")[1]: param.shape[0]
112+
for key, param in state_dict.items()
113+
if key.startswith("heads.") and key.endswith(".bias")
114+
}
115+
116+
return {
117+
"input_size": input_size,
118+
"hidden_size": hidden_size,
119+
"output_heads": output_heads,
120+
}
121+
122+
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
123+
"""
124+
Forward pass through the MLP.
125+
126+
Args:
127+
x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
128+
129+
Returns:
130+
dict[str, torch.Tensor]: A dictionary mapping head names to output tensors.
131+
"""
132+
x = self.feature_extractor(x)
133+
return {name: head(x) for name, head in self.heads.items()}

0 commit comments

Comments
 (0)