|
168 | 168 | "metadata": {}, |
169 | 169 | "outputs": [], |
170 | 170 | "source": [ |
171 | | - "from utils import VectorialDataset\n", |
| 171 | + "#%% Dataset to manage vector to vector data\n", |
| 172 | + "class VectorialDataset(torch.utils.data.Dataset):\n", |
| 173 | + " def __init__(self, input_data, output_data):\n", |
| 174 | + " super(VectorialDataset, self).__init__()\n", |
| 175 | + " self.input_data = torch.tensor(input_data.astype('f'))\n", |
| 176 | + " self.output_data = torch.tensor(output_data.astype('f'))\n", |
| 177 | + " \n", |
| 178 | + " def __len__(self):\n", |
| 179 | + " return self.input_data.shape[0]\n", |
| 180 | + " \n", |
| 181 | + " def __getitem__(self, idx):\n", |
| 182 | + " if torch.is_tensor(idx):\n", |
| 183 | + " idx = idx.tolist()\n", |
| 184 | + " sample = (self.input_data[idx, :], \n", |
| 185 | + " self.output_data[idx, :]) \n", |
| 186 | + " return sample " |
| 187 | + ] |
| 188 | + }, |
| 189 | + { |
| 190 | + "cell_type": "code", |
| 191 | + "execution_count": null, |
| 192 | + "metadata": {}, |
| 193 | + "outputs": [], |
| 194 | + "source": [ |
172 | 195 | "training_set = VectorialDataset(input_data=X_train, output_data=y_train)" |
173 | 196 | ] |
174 | 197 | }, |
|
287 | 310 | "metadata": {}, |
288 | 311 | "outputs": [], |
289 | 312 | "source": [ |
290 | | - "from utils import LinearModel\n", |
| 313 | + "#%% Linear layer\n", |
| 314 | + "class LinearModel(nn.Module):\n", |
| 315 | + " def __init__(self, input_dim, output_dim):\n", |
| 316 | + " super(LinearModel, self).__init__()\n", |
| 317 | + "\n", |
| 318 | + " self.input_dim = input_dim\n", |
| 319 | + " self.output_dim = output_dim\n", |
| 320 | + "\n", |
| 321 | + " self.linear = nn.Linear(self.input_dim, self.output_dim, bias=True)\n", |
| 322 | + "\n", |
| 323 | + " def forward(self, x):\n", |
| 324 | + " out = self.linear(x)\n", |
| 325 | + " return out\n", |
| 326 | + " \n", |
| 327 | + " def reset(self):\n", |
| 328 | + " self.linear.reset_parameters()" |
| 329 | + ] |
| 330 | + }, |
| 331 | + { |
| 332 | + "cell_type": "code", |
| 333 | + "execution_count": null, |
| 334 | + "metadata": {}, |
| 335 | + "outputs": [], |
| 336 | + "source": [ |
291 | 337 | "model = LinearModel(input_dim, output_dim)" |
292 | 338 | ] |
293 | 339 | }, |
|
0 commit comments