-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodels.py
More file actions
25 lines (22 loc) · 1 KB
/
models.py
File metadata and controls
25 lines (22 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import torch
import torch.nn as nn
class ManyLossMinimaModel(nn.Module):
    """Small feed-forward network built to have a highly non-convex loss surface.

    Architecture: ``Linear -> sin -> Linear -> tanh -> Linear``. The sine
    activation on the first layer introduces periodicity, the second layer
    applies ``tanh``, and the final layer is a plain linear combination of
    those nonlinear features (no output activation).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of both hidden layers. Defaults to 20.
        output_dim: Size of the last dimension of the output. Defaults to 1.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 20, output_dim: int = 1):
        super().__init__()
        # Layers are created in the same order as before (fc1, fc2, fc3) so that
        # seeded parameter initialization is unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        # Keep every layer size on the instance (not only input_dim) so callers
        # can introspect the architecture without reaching into the Linear layers.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``.

        Leading dimensions are passed through unchanged by the Linear layers.
        """
        # sin() makes the response periodic in the first-layer pre-activations,
        # which is the intended source of many local minima in the loss.
        # NOTE(review): design intent is that such a landscape favours
        # optimizers with momentum -- not demonstrated by this module itself.
        h = torch.sin(self.fc1(x))
        h = torch.tanh(self.fc2(h))
        return self.fc3(h)