Skip to content

Add deep EBMs#66

Open
nbereux wants to merge 32 commits into
DsysDML:developfrom
AidanLiotard:develop
Open

Add deep EBMs#66
nbereux wants to merge 32 commits into
DsysDML:developfrom
AidanLiotard:develop

Conversation

@nbereux
Copy link
Copy Markdown
Contributor

@nbereux nbereux commented May 14, 2026

No description provided.

def named_parameters(self) -> dict[str, np.ndarray]:
return {
name: tensor.detach().cpu().numpy()
for name, tensor in self.energy.state_dict().items()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use self.energy.named_parameters()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retire le du dossier

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retire le du dossier

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retire le du dossier

Comment thread nb/manifest.json
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retire le du dossier

Comment thread rbms/classes.py
Comment on lines +193 to +207
@abstractmethod
def sample_visibles(
self, chains: dict[str, Tensor], beta: float = 1.0
) -> dict[str, Tensor]:
"""Sample the visible layer.

Args:
chains (dict[str, Tensor]): The parallel chains used for sampling.
beta (float, optional): The inverse temperature. Defaults to 1.0.

Returns:
dict[str, Tensor]: The updated chains with sampled hidden states.
"""
...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

à mettre dans la classe RBM

Comment thread rbms/classes.py
Comment on lines 151 to 175
@staticmethod
@abstractmethod
def init_parameters(
num_hiddens: int,
num_visibles: int,
dataset: RBMDataset,
device: torch.device | str,
dtype: torch.dtype,
var_init: float = 1e-4,
) -> EBM:
"""Initialize the parameters of the RBM.

Args:
num_hiddens (int): Number of hidden units.
num_visibles (int): The number of visible units.
dataset (RBMDataset): Training dataset.
device (torch.device): PyTorch device for the parameters.
dtype (torch.dtype): PyTorch dtype for the parameters.
var_init (float, optional): Variance of the weight matrix. Defaults to 1e-4.

Notes:
- The number of visible units is induced from the dataset provided.
- Hidden biases are set to 0.
- Visible biases are set to the frequencies of the dataset.
- The weight matrix is initialized with a Gaussian distribution of variance `var_init`.
"""
...
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

à mettre dans la classe RBM

Comment thread rbms/classes.py
Comment on lines +218 to +219
kernel (Optional[Kernel]): The Markov kernel to use for sampling. Defaults to None.
kernel_params (Optional[dict]): The parameters for the Markov kernel. Defaults to None.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

à mettre directement dans la doc de sample_state dans les deep EBMs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ca degage

Comment thread uv.lock
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ca dégage

@nbereux
Copy link
Copy Markdown
Contributor Author

nbereux commented May 14, 2026

dans .gitignore il faut rajouter des lignes

**/.DS_store
**/*.ipynb
**/*.json

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants