Skip to content
This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 5ebc6a6

Browse files
add time in dataset
1 parent 0014690 commit 5ebc6a6

3 files changed

Lines changed: 75 additions & 34 deletions

File tree

notebook/midi.ipynb

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"from torch import nn\n",
2121
"import torch.nn.functional as F\n",
2222
"\n",
23-
"from notepredictor import PitchPredictor, MIDIPitchDataset"
23+
"from notepredictor import PitchPredictor, MIDIDataset"
2424
]
2525
},
2626
{
@@ -62,68 +62,95 @@
6262
},
6363
{
6464
"cell_type": "code",
65-
"execution_count": 5,
65+
"execution_count": 11,
6666
"id": "476a319c-8fa2-4a87-a69d-dd96d94cb766",
6767
"metadata": {},
6868
"outputs": [],
6969
"source": [
7070
"batch_size = 32\n",
7171
"batch_len = 64\n",
7272
"\n",
73-
"ds = MIDIPitchDataset(data_dir, batch_len)\n",
74-
"dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0, collate_fn=torch.stack)"
73+
"ds = MIDIDataset(data_dir, batch_len)\n",
74+
"dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0)"
7575
]
7676
},
7777
{
7878
"cell_type": "code",
79-
"execution_count": 6,
79+
"execution_count": 12,
8080
"id": "093c8d3b-ba0a-4a05-afbf-1ec88536080d",
8181
"metadata": {},
8282
"outputs": [
8383
{
8484
"name": "stdout",
8585
"output_type": "stream",
8686
"text": [
87-
"CPU times: user 20.1 ms, sys: 14.8 ms, total: 34.9 ms\n",
88-
"Wall time: 56.2 ms\n"
87+
"CPU times: user 21.6 ms, sys: 9.01 ms, total: 30.7 ms\n",
88+
"Wall time: 30.4 ms\n"
8989
]
9090
},
9191
{
9292
"data": {
9393
"text/plain": [
94-
"torch.Size([32, 64])"
94+
"(torch.Size([32, 64]), torch.Size([32, 64]))"
9595
]
9696
},
97-
"execution_count": 6,
97+
"execution_count": 12,
9898
"metadata": {},
9999
"output_type": "execute_result"
100100
}
101101
],
102102
"source": [
103103
"%%time\n",
104104
"batch = next(iter(dl))\n",
105-
"batch.shape"
105+
"batch['pitch'].shape, batch['time'].shape"
106106
]
107107
},
108108
{
109109
"cell_type": "code",
110-
"execution_count": 7,
110+
"execution_count": 13,
111+
"id": "029c4374",
112+
"metadata": {},
113+
"outputs": [
114+
{
115+
"data": {
116+
"text/plain": [
117+
"tensor([[0.0000, 0.5278, 0.0000, ..., 0.0278, 0.0000, 0.5278],\n",
118+
" [0.0000, 0.0000, 0.0072, ..., 0.0000, 0.0000, 0.0036],\n",
119+
" [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0208],\n",
120+
" ...,\n",
121+
" [0.0000, 0.0000, 0.0000, ..., 0.1654, 0.0000, 0.0000],\n",
122+
" [0.0011, 0.0000, 0.0042, ..., 0.0042, 0.0000, 0.0000],\n",
123+
" [0.0000, 0.0240, 0.0000, ..., 0.0000, 0.0000, 0.0000]])"
124+
]
125+
},
126+
"execution_count": 13,
127+
"metadata": {},
128+
"output_type": "execute_result"
129+
}
130+
],
131+
"source": [
132+
"batch['time']"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": 21,
111138
"id": "ffd81a4e",
112139
"metadata": {},
113140
"outputs": [
114141
{
115142
"data": {
116143
"text/plain": [
117-
"<matplotlib.collections.PathCollection at 0x13c74c460>"
144+
"<matplotlib.collections.PathCollection at 0x13ec95160>"
118145
]
119146
},
120-
"execution_count": 7,
147+
"execution_count": 21,
121148
"metadata": {},
122149
"output_type": "execute_result"
123150
},
124151
{
125152
"data": {
126-
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVOUlEQVR4nO3df4xlZ1nA8e9DW3RAzbR0IduBdddkU8BWujBBcJVAK4JI7NoEAlGzUeL+YyKoQbcxkfgHdg2JkcQfSQNoEw2KWLcNjdRmV9SYCM6yRYq1Vm0Xdlq7q3RjIhss9fGPucPObu/M3rnnnnvO+57vJ9ncuefOj+c977lPT59znvdGZiJJqsvzug5AkjR7JndJqpDJXZIqZHKXpAqZ3CWpQld2HQDAtddem7t37+46DEkqyokTJ/4zM3eMe60XyX337t2srKx0HYYkFSUiTm32mmUZSaqQyV2SKmRyl6QKmdwlqUImd0mq0GXvlomIjwFvB85k5g2jbdcAfwLsBh4H3pmZT49eux14D/As8HOZeX8rkQ/I0ZOrfOj+R3ji3HmuW1zg/W+5ngP7lroOS5qbvr4H+hoXTHbm/gfAWy/Zdhg4lpl7gWOj50TEK4F3Ad89+pnfjYgrZhbtAB09ucrtd3+R1XPnSWD13Hluv/uLHD252nVo0lz09T3Q17jWXTa5Z+bfAF+9ZPOtwF2jr+8CDmzY/seZ+fXMfAz4V+C1swl1mD50/yOcf+bZi7adf+ZZPnT/Ix1FJM1XX98DfY1r3bQ195dk5pMAo8cXj7YvAV/Z8H2nR9ueIyIORcRKRKycPXt2yjDq98S589vaLtWmr++Bvsa1btYXVGPMtrGfBpKZd2bmcmYu79gxtntWwHWLC9vaLtWmr++Bvsa1btrk/lRE7AQYPZ4ZbT8NvGzD970UeGL68PT+t1zPwlUXX7ZYuOoK3v+W6zuKSJqvvr4H+hrXummT+73AwdHXB4F7Nmx/V0R8S0TsAfYCn2sW4rAd2LfEHbfdyNLiAgEsLS5wx2039uaKvNS2vr4H+hrXurjcZ6hGxMeBNwLXAk8BHwCOAp8AdgFfBt6RmV8dff+vAD8NfAN4X2b+xeWCWF5eThcOk6TtiYgTmbk87rXL3ueeme/e5KVbNvn+DwIfnDw8SdKs9WLJX3VrkkaMPjdrqFxtH1dDPm5N7gO33oixfr/ueiMG8M03wSTfI21X28fV0I9b15YZuEkaMfrerKEytX1cDf24NbkP3CSNGH1v1lCZ2j6uhn7cmtwHbpJGjL43a6hMbR9XQz9uTe4DN0kjRt+bNVSmto+roR+3XlAduPULS1vdUTDJ90jb1fZxNfTj9rJNTPNgE5Mkbd9WTUyWZSSpQpZlNFNDbhopSZN5co7LYHLXzAy9aaQUTebJOS6HZRnNzNCbRkrRZJ6c43KY3DUzQ28aKUWTeXKOy2FZRjNz3eICq2Pe5ENpGumLy9XEm8yTczw7bV+78MxdMzP0ppE+WK+Jr547T3KhJn705Oo3v6fJPDnHszHJPDVlctfM9P2TaYZgkpp4k3lyjmdjHtcuLMtopg7sW/KN3qFJa+JN5sk5bm4e1y48c5cqMvTFskoxj3kyuUsVsSZehnnMk2UZqSJDXyyrFPOYJxcOk6RCuXCYJA2MZRm1rouFpmpc3KrGMak9Jne1qouFpmpc3KrGMaldlmXUqi4Wmqpxcasax6R2mdzVqi4Wmqpxcasax6R2mdzVqi6aamps5KlxTGqXyV2t6qKppsZGnhrHpHZ5QVWt6qKppsZGnhrHpHbZxCRJhbKJSZIGxrJMh2xKuaDP+6Lt2Po89hqN299QX8nLskxHLm1KgbULZEP84IM+74u2Y+vz2Gs0bn9f9byAgGeevZALS5kDyzI9ZFPKBX3eF23H1uex12jc/n7m//KixA51zIHJvSM2pVzQ533Rdmx9HnuNtrNfS5+DRsk9It4bEQ9FxJci4n2jbddExAMR8ejo8eqZRFoZm1Iu6PO+aDu2Po+9iaMnV9l/5Dh7Dt/H/iPHZ/rBz01sZ7+WPgdTJ/eIuAH4GeC1wKuAt0fEXuAwcCwz9wLHRs91CZtSLujzvmg7tj6PfVrrde3Vc+dJLixy1ocEP25/X/W84Kor4qJtpc8BNLtb5hXA32fm1wAi4q+BHwNuBd44+p67gM8Av9zg71TJppQL+rwv2o6tz2Of1lbXEboe12b7e9y2rmNtauq7ZSLiFcA9wOuB86ydpa8AP5mZixu+7+nMfE5pJiIOAYcAdu3a9ZpTp05NFYekftlz+D7GZZUAHjvyI/MOp2qt3C2TmQ8DvwE8AHwa+ALwjW38/J2ZuZyZyzt27Jg2DEk9U+t1hNI0amLKzI8CHwWIiF8HTgNPRcTOzHwyInYCZ5qHqXmrsbHGMc3H+99y/dh790uvYZemUXKPiBdn5pmI2AXcxlqJZg9wEDgyeryncZSaqxo/9ccxzU+N1xFK1KhDNSL+FngR8AzwC5l5LCJeBHwC2AV8GXhHZn51q98zxA7VPtt/5DirY+7xXVpc4O8O39xBRM05JtVoq5p707LMD4zZ9l/ALU1+r7pVY2ONY9LQuHBYxaatx163uDD2jHDeF8RmWU+e9Zgmja3Nmnhf5kkX68t1EJcfqFSTRpI+NNbMuhFmlmOaNLa2m3n6ME+6WJ8auEzulWqyINWBfUvccduNLC0uEKzVcOe9Qt6sF9Sa5Zgmja3tRcH6ME+6WJ8WgrMsU6mm9dgD+5Y6TRJt1JNnNaZJY5tHTbzredLF+nQdxDP3SpXeSNLn+CeNrc9jUDv6NOcm90qVXo/tc/yTxtbnMagdfZpzyzKVKr2RpM/xTxpbn8egdvRpzv2YPUkqlB+zJ0kDY3KXpAqZ3CWpQiZ3SaqQyV2SKmRyl6QKmdwlqUImd0mqkMldkipkcpekCrm2TEN9+dQVSdrI5N5AXz99XpIsyzTQp09dkaSNTO4N9OlTVyRpo0GWZWZVJ/fT5yX11eDO3Gf56eR9+tQVSdpocMl9lnVyP31eUl8Nriwz6zq5nz4vqY8Gd+bep08nl6S2DC65WyeXNASDK8v06dPJJaktg0vuYJ1cUv0GV5aRpCEY5Jn7OC4AJqkmJndcAExSfSzL4AJgkupjcscFwCTVp1Fyj4ifj4gvRcRDEfHxiPjWiLgmIh6IiEdHj1fPKti22NgkqTZTJ/eIWAJ+DljOzBuAK4B3AYeBY5m5Fzg2et5rNjZJqk3TssyVwEJEXAm8AHgCuBW4a/T6XcCBhn+jdS4AJqk2kZnT/3DEe4EPAueBv8zMH4+Ic5m5uOF7ns7M55RmIuIQcAhg165drzl16tTUcUjSEEXEicxcHvfa1LdCjmrptwJ7gHPAn0bET0z685l5J3AnwPLy8vT/hZkz74eXVIIm97n/IPBYZp4FiIi7ge8DnoqInZn5ZETsBM7MIM5e8H54SaVoUnP/MvC6iHhBRARwC/AwcC9wcPQ9B4F7moXYH94PL6kUU5+5Z+ZnI+KTwOeBbwAnWSuzfBvwiYh4D2v/AXjHLALtA++Hl1SKRssPZOYHgA9csvnrrJ3FV8cPxJZUCjtUt8H74SWVwoXDtsEP+pBUCpP7NvlBH5JKYFlGkipkcpekCpncJalCJndJqpDJXZIqZHKXpAp5K+ScuJqkpHkyuc+Bq0lKmjfLMnPgapKS5s3kPgeuJilp3kzuc7DZqpGuJimpLSb3OXA1SUnz5gXVOXA1SUnzZnKfE1eTlDRPRSd37x2XpPGKTe7eOy5Jmyv2gqr3jkvS5opN7t47LkmbKza5e++4JG2u2OTuveOStLliL6h677gkba7Y5A7eOy5Jmym2LCNJ2lzRZ+66YCgNXX0e5ySx9Tl+taeLeTe5V2AoDV19HucksfU5frWnq3m3LFOBoTR09Xmck8TW5/jVnq7m3eRegaE0dPV5nJPE1uf41Z6u5t3kXoGhNHT1eZyTxNbn+NWerubd5F6BEhu6jp5cZf+R4+w5fB/7jxzn6MnVy/5Mn8c5SWx9jr9E0xxDXehq3r2gWoHSGrqmvcDU53FOEluf4y9NSRenu5r3yMxW/8AklpeXc2VlpeswNCf7jxxndUy9cWlxgb87fHMHEak0HkNrIuJEZi6Pe82yjObOC4tqymPo8qZO7hFxfUQ8uOHff0fE+yLimoh4ICIeHT1ePcuAVT4vLKopj6HLmzq5Z+YjmXlTZt4EvAb4GvDnwGHgWGbuBY6Nnkvf5IVFNeUxdHmzuqB6C/BvmXkqIm4F3jjafhfwGeCXZ/R3VAEvLKopj6HLm8kF1Yj4GPD5zPztiDiXmYsbXns6M59TmomIQ8AhgF27dr3m1KlTjeOQpCHZ6oJq4zP3iHg+8KPA7dv5ucy8E7gT1u6WaRqH5s9FsOrifNZlFmWZH2btrP2p0fOnImJnZj4ZETuBMzP4G+qZku4z1uU5n/WZxa2Q7wY+vuH5vcDB0dcHgXtm8DfUMy6CVRfnsz6NkntEvAB4M3D3hs1HgDdHxKOj1440+RvqJ+8zrovzWZ9GZZnM/Brwoku2/Rdrd8+oYtctLoztEPQ+4zI5n/WxQ3VgZrXYkvcZ12Wz+XzTy3cUsTiXnsuFwwZklhfNvM+4LuPm800v38GfnVj1ImuhXDhsQFxsSdvh8dJ/LhwmwItm2h6Pl7JVX5YprTGjzXi9aHax0o6NaU07ztKOl1nPZ+nHR9Vn7us15tVz50ku1Az7elGo7Xi9CHpBacfGtJqMs6TjZdbzWcPxUXVyL60xo+14D+xb4o7bbmRpcYFgrXZ6x203FnU2MiulHRvTajLOko6XWc9nDcdH1WWZ0mqG84j3wL6lXr455620Y2NaTcdZyvEy6/ms4fio+sy9tAX9S4u3ZEPZ145zunHWsN+qTu4l1QyhvHhLNpR97TinG2cN+63qskxpjTalxVuyoexrxzndOGvYbzYxSVKhbGKSpIGpuiwjlar0Bhp1z+Qu9YyfiqRZsCwj9UwNDTTqnsld6pkaGmjUPZO71DM1NNCoeyZ3qWdqaKBR97ygKvVMDQ006p7JXeqhUhbsUn9ZlpGkCpncJalCJndJqpDJXZIqZHKXpApVdbeMiy1J5fN9PBvVJHcXW5LK5/t4dqopy7jYklQ+38ezU01yd7ElqXy+j2enmuTuYktS+Xwfz041yd3FlqTy+T6enWouqLrYklQ+38ezE5nZdQwsLy/nyspK12FIUlEi4kRmLo97rZqyjCTpgkZlmYhYBD4C3AAk8NPAI8CfALuBx4F3ZubTTf5OadpuwrDJoz3uW22lpOOj6Zn7h4FPZ+bLgVcBDwOHgWOZuRc4Nno+GOtNGKvnzpNcaMI4enK1iN8/ZO5bbaW042Pq5B4R3wG8AfgoQGb+b2aeA24F7hp9213AgWYhlqXtJgybPNrjvtVWSjs+mpy5fxdwFvj9iDgZER+JiBcCL8nMJwFGjy8e98MRcSgiViJi5ezZsw3C6Je2mzBs8miP+1ZbKe34aJLcrwReDfxeZu4D/odtlGAy887MXM7M5R07djQIo1/absKwyeO5jp5cZf+R4+w5fB/7jxyf+n+T3bftmtU8daW046NJcj8NnM7Mz46ef5K1ZP9UROwEGD2eaRZiWdpuwrDJ42KzrIO6b9tTWr16nNKOj6mTe2b+B/CViFgf2S3APwH3AgdH2w4C9zSKsDAH9i1xx203srS4QABLiwvccduNM7ui3vbvL80s66Du2/aUVq8ep7Tjo1ETU0TcxNqtkM8H/h34Kdb+g/EJYBfwZeAdmfnVrX6PTUya1p7D9zHuCA7gsSM/Mu9wtAnnqR1bNTE1us89Mx8Exv3iW5r8XmlS1y0usDrmglZf66BD5TzNnx2qKlppddChcp7mr5qFwzRMLjRVBudp/lw4TJIK5cJhkjQwlmUkNVbSglpDYXKX1Mh6g9L6fezrDUqACb5DlmUkNVJDg1KNTO6SGiltQa2hMLlLaqS0BbWGwuQuqREblPrJC6qSGrFBqZ9M7pIaO7BvyWTeM5ZlJKlCnrlvwcYMSaUyuW/CxgxJJbMsswkbMySVzOS+CRszJJXM5L4JGzMklczkvgkbMySVzAuqm7AxQ1LJTO5bsDFDUqksy0hShUzuklQhk7skVcjkLkkVMrlLUoUiM7uOgYg4C5xq8CuuBf5zRuF0wfi7VXr8UP4YjH8635mZO8a90Ivk3lRErGTmctdxTMv4u1V6/FD+GIx/9izLSFKFTO6SVKFakvudXQfQkPF3q/T4ofwxGP+MVVFzlyRdrJYzd0nSBiZ3SapQ0ck9It4aEY9ExL9GxOGu45lERHwsIs5ExEMbtl0TEQ9ExKOjx6u7jHErEfGyiPiriHg4Ir4UEe8dbS9iDBHxrRHxuYj4wij+XxttLyL+dRFxRUScjIhPjZ6XFv/jEfHFiHgwIlZG24oZQ0QsRsQnI+KfR++F1/ct/mKTe0RcAfwO8MPAK4F3R8Qru41qIn8AvPWSbYeBY5m5Fzg2et5X3wB+MTNfAbwO+NnRfi9lDF8Hbs7MVwE3AW+NiNdRTvzr3gs8vOF5afEDvCkzb9pwf3hJY/gw8OnMfDnwKtbmol/xZ2aR/4DXA/dveH47cHvXcU0Y+27goQ3PHwF2jr7eCTzSdYzbGMs9wJtLHAPwAuDzwPeWFD/wUtaSx83Ap0o8hoDHgWsv2VbEGIDvAB5jdENKX+Mv9swdWAK+suH56dG2Er0kM58EGD2+uON4JhIRu4F9wGcpaAyjksaDwBnggcwsKn7gt4BfAv5vw7aS4gdI4C8j4kREHBptK2UM3wWcBX5/VBr7SES8kJ7FX3JyjzHbvK9zTiLi24A/A96Xmf/ddTzbkZnPZuZNrJ0BvzYibug4pIlFxNuBM5l5outYGtqfma9mraz6sxHxhq4D2oYrgVcDv5eZ+4D/oesSzBglJ/fTwMs2PH8p8ERHsTT1VETsBBg9nuk4ni1FxFWsJfY/ysy7R5uLGgNAZp4DPsPaNZBS4t8P/GhEPA78MXBzRPwh5cQPQGY+MXo8A/w58FrKGcNp4PTo//gAPslasu9V/CUn938A9kbEnoh4PvAu4N6OY5rWvcDB0dcHWatj91JEBPBR4OHM/M0NLxUxhojYERGLo68XgB8E/plC4s/M2zPzpZm5m7Vj/nhm/gSFxA8QES+MiG9f/xr4IeAhChlDZv4H8JWIuH606Rbgn+hb/F1fnGh4YeNtwL8A/wb8StfxTBjzx4EngWdYOwN4D/Ai1i6QPTp6vKbrOLeI//tZK3/9I/Dg6N/bShkD8D3AyVH8DwG/OtpeRPyXjOWNXLigWkz8rNWsvzD696X1925hY7gJWBkdR0eBq/sWv8sPSFKFSi7LSJI2YXKXpAqZ3CWpQiZ3SaqQyV2SKmRyl6QKmdwlqUL/D/uDGO0vwBf8AAAAAElFTkSuQmCC",
153+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAATUElEQVR4nO3df6zd9X3f8ecrlxtx05WajksHGOZMbby2SQrtLWVBXTMnrVlwAWXtlEhUqJ1qKWpTwhbTeKnaMU0KwtFKpFabEGFDJUpHM9eN0JiLmrkTf2B6HUMII26qlhLspL4sczc2F4x57497TK7ta9/z857zOXk+pKt7zufcz/2+zrm+L3/9/X6Ov6kqJEntedO4A0iS+mOBS1KjLHBJapQFLkmNssAlqVEXrOfGLrnkktq0adN6blKSmnfgwIGXqmr+zPF1LfBNmzaxuLi4npuUpOYl+cvVxj2EIkmNssAlqVEWuCQ1ygKXpEZZ4JLUqHVdhdKPPQcPs2vvIY4cO87lG+bYsXUzt1xzhXMHmLvn4GH+1eef5djxEwBc/JZZfuOnf3DkGXuZ92t7nuGz+7/GySpmEj74Y1fyb255x0jz9TJ3XK+hGdvLOEpZz/+NcGFhoXpZRrjn4GF27n6G4ydOvjE2NzvDJ97/jjVfPOeuPnfPwcPs+L2nOfH66T/32Zmw62d+aGQZe5n3a3ue4aEnXjjre9x63VVrlvg0v4ZmbC/jsCQ5UFULZ45P9CGUXXsPnfYCAxw/cZJdew85t8+5u/YeOusPI8CJkzXSjL3M++z+r636Pc41Pox8vcwd12toxvYyjtpEF/iRY8d7Gnfu2uPn+16jzNjLvJPn+Ffhucb73U6/c8f1GvYy14zDmTtoxlGb6AK/fMNcT+POXXv8fN9rlBl7mTeTrPq15xrvdzv9zh3Xa9jLXDMOZ+6gGUdtogt8x9bNzM3OnDY2NzvDjq2bndvn3B1bNzP7prOLcHYmI83Yy7wP/tiVq36Pc40PI18vc8f1GpqxvYyjNtGrUE6dIOjnLLNzV5976n6/Z9X7zdjLvFMnKvtZhTLNr6EZ28s4ahO9CkWS1OgqFEnSuVngktQoC1ySGmWBS1Kjui7wJDNJDiZ5pHP/6iRPJHkqyWKSa0cXU5J0pl72wG8Hnltx/x7grqq6Gvj1zn1J0jrpqsCTbARuBO5fMVzARZ3b3wUcGW40SdL5dPtGnnuBO4HvXDH2EWBvkk+y/BfBu4aaTJJ0XmvugSfZBhytqgNnPPQh4I6quhK4A/j0OeZv7xwjX1xaWho4sCRp2ZrvxEzyCeDngNeAC1k+bLIb+GlgQ1VVkgB/XVUXnfs7+U5MSepH3+/ErKqdVbWxqjYBHwC+UFW3snzM+yc6X7YF+OoQ80qS1jDIf2b1i8CnklwA/A2wfTiRJEnd6KnAq2ofsK9z+3HgR4YfSZLUDd+JKUmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqVNcFnmQmycEkj6wY+3CSQ0meTXLPaCJKklbTyxV5bgeeY/mixiT5R8DNwDur6pUkl44gnyTpHLraA0+yEbgRuH/F8IeAu6vqFYCqOjr8eJKkc+n2EMq9wJ3A6yvG3gb8eJL9Sf44yY+uNjHJ9iSLSRaXlpYGSytJesOaBZ5kG3C0qg6c8dAFwMXAdcAO4OEkOXN+Vd1XVQtVtTA/Pz+MzJIkujsGfj1wU5L3ARcCFyV5CHgR2F1VBTyZ5HXgEsDdbElaB2vugVfVzqraWFWbgA8AX6iqW4E9wBaAJG8D3gy8NLqokqSVelmFcqYHgAeSfBl4FbitszcuSVoHPRV4Ve0D9nVuvwrcOvxIkqRu+E5MSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1Kjui7wJDNJDiZ55IzxjyapJJcMP54k6Vx62QO/HXhu5UCSK4GfBF4YZihJ0tq6KvAkG4EbgfvPeOg3gTsBr4UpSeus2z3we1ku6tdPDSS5CThcVU+fb2KS7UkWkywuLS31HVSSdLo1CzzJNuBoVR1YMfYW4OPAr681v6ruq6qFqlqYn58fKKwk6Vu6uSr99cBNSd4HXAhcBPwO8Fbg6SQAG4EvJrm2qr4xqrCSpG9Zs8CraiewEyDJu4GPVtU/Wfk1SZ4HFqrqpeFHlCStxnXgktSobg6hvKGq9gH7VhnfNJw4kqRuuQcuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWpU1wWeZCbJwSSPdO7vSvKVJF9K8vtJNowspSTpLL3sgd8OPLfi/mPA26vqncCf0rlupiRpfXRV4Ek2AjcC958aq6o/rKrXOnefYPnK9JKkddLtHvi9wJ3A6+d4/BeAR1d7IMn2JItJFpeWlnpPKEla1ZoFnmQbcLSqDpzj8Y8DrwGfWe3xqrqvqhaqamF+fn6gsJKkb+nmqvTXAzcleR9wIXBRkoeq6tYktwHbgPdUVY0yqCTpdGvugVfVzqraWFWbgA8AX+iU9w3ArwI3VdX/G3FOSdIZBlkH/lvAdwKPJXkqyb8fUiZJUhe6OYTyhqraB+zr3P7eEeSRJHXJd2JKUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUqK7/P/AkM8AicLiqtiX5buA/AZuA54F/WlX/axQh+7Xn4GF27T3EkWPHuXzDHDu2buaWa64Y+dxBtJC53+1Mej4zmnHSMq4l3V7KMsk/BxaAizoFfg/wzaq6O8nHgIur6lfP9z0WFhZqcXFx4NDd2HPwMDt3P8PxEyffGJubneET73/Hmi/eIHOnPXO/25n0fGY046RlXCnJgapaOHO8q0MoSTYCNwL3rxi+GXiwc/tB4Jau06yDXXsPnfaiARw/cZJdew+NdO4gWsjc73YmPZ8ZzThpGbvR7THwe4E7gddXjH1PVX0doPP50tUmJtmeZDHJ4tLS0iBZe3Lk2PGexoc1dxAtZO53O5Oeb9C5vTDjcEx7xm6sWeBJtgFHq+pAPxuoqvuqaqGqFubn5/v5Fn25fMNcT+PDmjuIFjL3u51Jzzfo3F6YcTimPWM3utkDvx64KcnzwO8CW5I8BPxVkssAOp+PDiXRkOzYupm52ZnTxuZmZ9ixdfNI5w6ihcz9bmfS85nRjJOWsRtrrkKpqp3AToAk7wY+WlW3JtkF3Abc3fn8B0NJNCSnThD0c/Z3kLnTnrnf7Ux6PjOacdIydqPrVShwWoFvS/K3gYeBq4AXgJ+tqm+eb/56rkKRpGlxrlUoXa8DB6iqfcC+zu3/CbxnGOEkSb3znZiS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZ1c1HjC5M8meTpJM8muaszfnWSJ5I81bnq/LWjjytJOqWbK/K8AmypqpeTzAKPJ3kU+NfAXVX1aJL3AfcA7x5dVEnSSt1c1LiAlzt3Zzsf1fm4qDP+XcCRUQSUJK2uq2tiJpkBDgDfC/x2Ve1P8hFgb5JPsnwo5l3nmLsd2A5w1VVXDSOzJIkuT2JW1cmquhrYCFyb5O3Ah4A7qupK4A7g0+eYe19VLVTVwvz8/JBiS5J6WoVSVcdYvir9DcBtwO7OQ78HeBJTktZRN6tQ5pNs6NyeA94LfIXlY94/0fmyLcBXR5RRkrSKbo6BXwY82DkO/ibg4ap6JMkx4FNJLgD+hs5xbknS+uhmFcqXgGtWGX8c+JFRhJIkrc13YkpSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGtXNJdUuTPJkkqeTPJvkrhWPfTjJoc74PaONKklaqZtLqr0CbKmql5PMAo8neRSYA24G3llVryS5dJRBJUmn6+aSagW83Lk72/ko4EPA3VX1Sufrjo4qpCTpbF0dA08yk+Qp4CjwWFXtB94G/HiS/Un+OMmPnmPu9iSLSRaXlpaGFlySvt11VeBVdbKqrgY2AtcmeTvLe+8XA9cBO4CHk2SVufdV1UJVLczPzw8vuSR9m+tpFUpVHQP2ATcALwK7a9mTwOvAJcMOKElaXTerUOaTbOjcngPeC3wF2ANs6Yy/DXgz8NKogkqSTtfNKpTLgAeTzLBc+A9X1SNJ3gw8kOTLwKvAbZ0TnpKkddDNKpQvAdesMv4qcOsoQkmS1uY7MSWpURa4JDXKApekRnVzErNZew4eZtfeQxw5dpzLN8yxY+tmbrnmipHPHUQLmfvdzqTnM6MZ1+N3fJiyngtHFhYWanFxcV22tefgYXbufobjJ06+MTY3O8Mn3v+ONX9Ig8yd9sz9bmfS85nRjOvxO96vJAeqauHM8ak9hLJr76HTfjgAx0+cZNfeQyOdO4gWMve7nUnPZ0Yzrsfv+LBNbYEfOXa8p/FhzR1EC5n73c6k5xt0bi/MOBwtZBy1qS3wyzfM9TQ+rLmDaCFzv9uZ9HyDzu2FGYejhYyjNrUFvmPrZuZmZ04bm5udYcfWzSOdO4gWMve7nUnPZ0Yzrsfv+LBN7SqUUyci+jnLPMjcac/c73YmPZ8ZzegqlDWs5yoUSZoW33arUCRp2lngktQoC1ySGmWBS1Kjurkiz4VJnkzydJJnk9x1xuMfTVJJvJyaJK2jbpYRvgJsqaqXk8wCjyd5tKqeSHIl8JPACyNNKUk6y5p74J2LFr/cuTvb+Ti19vA3gTtX3JckrZOujoEnmUnyFHAUeKyq9ie5CThcVU+PMqAkaXVdvROzqk4CV3euTv/7Sd4JfBz4qbXmJtkObAe46qqr+k8qSTpNT6tQquoYsA+4GXgr8HSS54GNwBeT/J1V5txXVQtVtTA/Pz9wYEnSsm5Wocx39rxJMge8FzhYVZdW1aaq2gS8CPxwVX1jlGElSd/SzSGUy4AHk8ywXPgPV9Ujo40lSVrLmgVeVV8CrlnjazYNK5AkqTu+E1OSGmWBS1KjLHBJapQFLkmNmtpLqo3TnoOHJ/5STZOecdLzgRmHxYz9s8CHbM/Bw+zc/QzHT5wE4PCx4+zc/QzARPzAYfIzTno+MOOwmHEwHkIZsl17D73xgz7l+ImT7Np7aEyJzjbpGSc9H5hxWMw4GAt8yI4cO97T+DhMesZJzwdmHBYzDsYCH7LLN8z1ND4Ok55x0vOBGYfFjIOxwIdsx9bNzM3OnDY2NzvDjq2bx5TobJOecdLzgRmHxYyD8STmkJ06qTGJZ6xPmfSMk54PzDgsZhxMqtbvYjoLCwu1uLi4btuTpGmQ5EBVLZw57iEUSWqUBS5JjbLAJalRFrgkNcoCl6RGresqlCRLwF/2Of0S4KUhxplE0/4cp/35wfQ/x2l/fjCZz/HvVtVZV4Vf1wIfRJLF1ZbRTJNpf47T/vxg+p/jtD8/aOs5eghFkhplgUtSo1oq8PvGHWAdTPtznPbnB9P/HKf9+UFDz7GZY+CSpNO1tAcuSVrBApekRjVR4EluSHIoyZ8l+di48wxTkiuT/LckzyV5Nsnt4840CklmkhxM8si4s4xCkg1JPpfkK52f5T8Yd6ZhS3JH58/ol5N8NsmF4840iCQPJDma5Msrxr47yWNJvtr5fPE4M65l4gs8yQzw28A/Bn4A+GCSHxhvqqF6DfgXVfX9wHXAL03Z8zvlduC5cYcYoU8B/7Wq/j7wQ0zZc01yBfArwEJVvR2YAT4w3lQD+4/ADWeMfQz4o6r6PuCPOvcn1sQXOHAt8GdV9edV9Srwu8DNY840NFX19ar6Yuf2/2H5F3/8/1P8ECXZCNwI3D/uLKOQ5CLgHwKfBqiqV6vq2FhDjcYFwFySC4C3AEfGnGcgVfXfgW+eMXwz8GDn9oPALeuZqVctFPgVwNdW3H+RKSu4U5JsAq4B9o85yrDdC9wJvD7mHKPy94Al4D90DhPdn+Q7xh1qmKrqMPBJ4AXg68BfV9UfjjfVSHxPVX0dlneugEvHnOe8WijwrDI2dWsfk/wt4D8DH6mq/z3uPMOSZBtwtKoOjDvLCF0A/DDw76rqGuD/MuH/9O5V51jwzcBbgcuB70hy63hTqYUCfxG4csX9jTT+T7czJZllubw/U1W7x51nyK4HbkryPMuHv7YkeWi8kYbuReDFqjr1L6fPsVzo0+S9wF9U1VJVnQB2A+8ac6ZR+KsklwF0Ph8dc57zaqHA/wT4viRvTfJmlk+cfH7MmYYmSVg+dvpcVf3bcecZtqraWVUbq2oTyz+7L1TVVO25VdU3gK8lOXWZ8vcA/2OMkUbhBeC6JG/p/Jl9D1N2orbj88Btndu3AX8wxixrmvir0lfVa0l+GdjL8pnvB6rq2THHGqbrgZ8DnknyVGfsX1bVfxlfJPXhw8BnOjsZfw78/JjzDFVV7U/yOeCLLK+cOkhDbzlfTZLPAu8GLknyIvAbwN3Aw0n+Gct/af3s+BKuzbfSS1KjWjiEIklahQUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGvX/Acmj+IjC2rUHAAAAAElFTkSuQmCC",
127154
"text/plain": [
128155
"<Figure size 432x288 with 1 Axes>"
129156
]
@@ -135,12 +162,13 @@
135162
}
136163
],
137164
"source": [
138-
"plt.scatter(range(batch.shape[1]), batch[7])"
165+
"i = 0\n",
166+
"plt.scatter(batch['time'][i].cumsum(0), batch['pitch'][i])"
139167
]
140168
},
141169
{
142170
"cell_type": "code",
143-
"execution_count": 8,
171+
"execution_count": 14,
144172
"id": "ea23aea6-1629-4d1c-b1be-0865f8a07e9d",
145173
"metadata": {},
146174
"outputs": [],
@@ -239,7 +267,7 @@
239267
"for batch in tqdm(it.islice(dl,512)):\n",
240268
" # batch = torch.LongTensor([notes for notes in it.islice(gen_tracks(batch_len), batch_size)])\n",
241269
" opt.zero_grad()\n",
242-
" r = net(batch)\n",
270+
" r = net(batch['pitch'])\n",
243271
" nll = (-r['log_probs']).mean()\n",
244272
" nll.backward()\n",
245273
" opt.step()\n",
@@ -407,7 +435,8 @@
407435
],
408436
"source": [
409437
"counts = torch.zeros(130,130).long()\n",
410-
"for s in tqdm(ds):\n",
438+
"for item in tqdm(ds):\n",
439+
" s = item['pitch']\n",
411440
" counts[s[:-1], s[1:]] += 1"
412441
]
413442
},
@@ -564,7 +593,8 @@
564593
],
565594
"source": [
566595
"counts = torch.zeros(130,130,130).long()\n",
567-
"for s in tqdm(ds):\n",
596+
"for item in tqdm(ds):\n",
597+
" s = item['pitch']\n",
568598
" counts[s[:-2], s[1:-1], s[2:]] += 1"
569599
]
570600
},

0 commit comments

Comments
 (0)