Skip to content

Commit 54285c2

Browse files
Tutorial2: moved functs from utils.py to notebook
1 parent 72d9aa0 commit 54285c2

1 file changed

Lines changed: 48 additions & 2 deletions

File tree

Tutorial2/Tutorial2.ipynb

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,30 @@
168168
"metadata": {},
169169
"outputs": [],
170170
"source": [
171-
"from utils import VectorialDataset\n",
171+
"#%% Dataset to manage vector to vector data\n",
172+
"class VectorialDataset(torch.utils.data.Dataset):\n",
173+
" def __init__(self, input_data, output_data):\n",
174+
" super(VectorialDataset, self).__init__()\n",
175+
" self.input_data = torch.tensor(input_data.astype('f'))\n",
176+
" self.output_data = torch.tensor(output_data.astype('f'))\n",
177+
" \n",
178+
" def __len__(self):\n",
179+
" return self.input_data.shape[0]\n",
180+
" \n",
181+
" def __getitem__(self, idx):\n",
182+
" if torch.is_tensor(idx):\n",
183+
" idx = idx.tolist()\n",
184+
" sample = (self.input_data[idx, :], \n",
185+
" self.output_data[idx, :]) \n",
186+
" return sample "
187+
]
188+
},
189+
{
190+
"cell_type": "code",
191+
"execution_count": null,
192+
"metadata": {},
193+
"outputs": [],
194+
"source": [
172195
"training_set = VectorialDataset(input_data=X_train, output_data=y_train)"
173196
]
174197
},
@@ -287,7 +310,30 @@
287310
"metadata": {},
288311
"outputs": [],
289312
"source": [
290-
"from utils import LinearModel\n",
313+
"#%% Linear layer\n",
314+
"class LinearModel(nn.Module):\n",
315+
" def __init__(self, input_dim, output_dim):\n",
316+
" super(LinearModel, self).__init__()\n",
317+
"\n",
318+
" self.input_dim = input_dim\n",
319+
" self.output_dim = output_dim\n",
320+
"\n",
321+
" self.linear = nn.Linear(self.input_dim, self.output_dim, bias=True)\n",
322+
"\n",
323+
" def forward(self, x):\n",
324+
" out = self.linear(x)\n",
325+
" return out\n",
326+
" \n",
327+
" def reset(self):\n",
328+
" self.linear.reset_parameters()"
329+
]
330+
},
331+
{
332+
"cell_type": "code",
333+
"execution_count": null,
334+
"metadata": {},
335+
"outputs": [],
336+
"source": [
291337
"model = LinearModel(input_dim, output_dim)"
292338
]
293339
},

0 commit comments

Comments
 (0)