|
20 | 20 | "from torch import nn\n", |
21 | 21 | "import torch.nn.functional as F\n", |
22 | 22 | "\n", |
23 | | - "from notepredictor import PitchPredictor, MIDIPitchDataset" |
| 23 | + "from notepredictor import PitchPredictor, MIDIDataset" |
24 | 24 | ] |
25 | 25 | }, |
26 | 26 | { |
|
62 | 62 | }, |
63 | 63 | { |
64 | 64 | "cell_type": "code", |
65 | | - "execution_count": 5, |
| 65 | + "execution_count": 11, |
66 | 66 | "id": "476a319c-8fa2-4a87-a69d-dd96d94cb766", |
67 | 67 | "metadata": {}, |
68 | 68 | "outputs": [], |
69 | 69 | "source": [ |
70 | 70 | "batch_size = 32\n", |
71 | 71 | "batch_len = 64\n", |
72 | 72 | "\n", |
73 | | - "ds = MIDIPitchDataset(data_dir, batch_len)\n", |
74 | | - "dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0, collate_fn=torch.stack)" |
| 73 | + "ds = MIDIDataset(data_dir, batch_len)\n", |
| 74 | + "dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0)" |
75 | 75 | ] |
76 | 76 | }, |
77 | 77 | { |
78 | 78 | "cell_type": "code", |
79 | | - "execution_count": 6, |
| 79 | + "execution_count": 12, |
80 | 80 | "id": "093c8d3b-ba0a-4a05-afbf-1ec88536080d", |
81 | 81 | "metadata": {}, |
82 | 82 | "outputs": [ |
83 | 83 | { |
84 | 84 | "name": "stdout", |
85 | 85 | "output_type": "stream", |
86 | 86 | "text": [ |
87 | | - "CPU times: user 20.1 ms, sys: 14.8 ms, total: 34.9 ms\n", |
88 | | - "Wall time: 56.2 ms\n" |
| 87 | + "CPU times: user 21.6 ms, sys: 9.01 ms, total: 30.7 ms\n", |
| 88 | + "Wall time: 30.4 ms\n" |
89 | 89 | ] |
90 | 90 | }, |
91 | 91 | { |
92 | 92 | "data": { |
93 | 93 | "text/plain": [ |
94 | | - "torch.Size([32, 64])" |
| 94 | + "(torch.Size([32, 64]), torch.Size([32, 64]))" |
95 | 95 | ] |
96 | 96 | }, |
97 | | - "execution_count": 6, |
| 97 | + "execution_count": 12, |
98 | 98 | "metadata": {}, |
99 | 99 | "output_type": "execute_result" |
100 | 100 | } |
101 | 101 | ], |
102 | 102 | "source": [ |
103 | 103 | "%%time\n", |
104 | 104 | "batch = next(iter(dl))\n", |
105 | | - "batch.shape" |
| 105 | + "batch['pitch'].shape, batch['time'].shape" |
106 | 106 | ] |
107 | 107 | }, |
108 | 108 | { |
109 | 109 | "cell_type": "code", |
110 | | - "execution_count": 7, |
| 110 | + "execution_count": 13, |
| 111 | + "id": "029c4374", |
| 112 | + "metadata": {}, |
| 113 | + "outputs": [ |
| 114 | + { |
| 115 | + "data": { |
| 116 | + "text/plain": [ |
| 117 | + "tensor([[0.0000, 0.5278, 0.0000, ..., 0.0278, 0.0000, 0.5278],\n", |
| 118 | + " [0.0000, 0.0000, 0.0072, ..., 0.0000, 0.0000, 0.0036],\n", |
| 119 | + " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0208],\n", |
| 120 | + " ...,\n", |
| 121 | + " [0.0000, 0.0000, 0.0000, ..., 0.1654, 0.0000, 0.0000],\n", |
| 122 | + " [0.0011, 0.0000, 0.0042, ..., 0.0042, 0.0000, 0.0000],\n", |
| 123 | + " [0.0000, 0.0240, 0.0000, ..., 0.0000, 0.0000, 0.0000]])" |
| 124 | + ] |
| 125 | + }, |
| 126 | + "execution_count": 13, |
| 127 | + "metadata": {}, |
| 128 | + "output_type": "execute_result" |
| 129 | + } |
| 130 | + ], |
| 131 | + "source": [ |
| 132 | + "batch['time']" |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "cell_type": "code", |
| 137 | + "execution_count": 21, |
111 | 138 | "id": "ffd81a4e", |
112 | 139 | "metadata": {}, |
113 | 140 | "outputs": [ |
114 | 141 | { |
115 | 142 | "data": { |
116 | 143 | "text/plain": [ |
117 | | - "<matplotlib.collections.PathCollection at 0x13c74c460>" |
| 144 | + "<matplotlib.collections.PathCollection at 0x13ec95160>" |
118 | 145 | ] |
119 | 146 | }, |
120 | | - "execution_count": 7, |
| 147 | + "execution_count": 21, |
121 | 148 | "metadata": {}, |
122 | 149 | "output_type": "execute_result" |
123 | 150 | }, |
124 | 151 | { |
125 | 152 | "data": { |
126 | | - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVOUlEQVR4nO3df4xlZ1nA8e9DW3RAzbR0IduBdddkU8BWujBBcJVAK4JI7NoEAlGzUeL+YyKoQbcxkfgHdg2JkcQfSQNoEw2KWLcNjdRmV9SYCM6yRYq1Vm0Xdlq7q3RjIhss9fGPucPObu/M3rnnnnvO+57vJ9ncuefOj+c977lPT59znvdGZiJJqsvzug5AkjR7JndJqpDJXZIqZHKXpAqZ3CWpQld2HQDAtddem7t37+46DEkqyokTJ/4zM3eMe60XyX337t2srKx0HYYkFSUiTm32mmUZSaqQyV2SKmRyl6QKmdwlqUImd0mq0GXvlomIjwFvB85k5g2jbdcAfwLsBh4H3pmZT49eux14D/As8HOZeX8rkQ/I0ZOrfOj+R3ji3HmuW1zg/W+5ngP7lroOS5qbvr4H+hoXTHbm/gfAWy/Zdhg4lpl7gWOj50TEK4F3Ad89+pnfjYgrZhbtAB09ucrtd3+R1XPnSWD13Hluv/uLHD252nVo0lz09T3Q17jWXTa5Z+bfAF+9ZPOtwF2jr+8CDmzY/seZ+fXMfAz4V+C1swl1mD50/yOcf+bZi7adf+ZZPnT/Ix1FJM1XX98DfY1r3bQ195dk5pMAo8cXj7YvAV/Z8H2nR9ueIyIORcRKRKycPXt2yjDq98S589vaLtWmr++Bvsa1btYXVGPMtrGfBpKZd2bmcmYu79gxtntWwHWLC9vaLtWmr++Bvsa1btrk/lRE7AQYPZ4ZbT8NvGzD970UeGL68PT+t1zPwlUXX7ZYuOoK3v+W6zuKSJqvvr4H+hrXummT+73AwdHXB4F7Nmx/V0R8S0TsAfYCn2sW4rAd2LfEHbfdyNLiAgEsLS5wx2039uaKvNS2vr4H+hrXurjcZ6hGxMeBNwLXAk8BHwCOAp8AdgFfBt6RmV8dff+vAD8NfAN4X2b+xeWCWF5eThcOk6TtiYgTmbk87rXL3ueeme/e5KVbNvn+DwIfnDw8SdKs9WLJX3VrkkaMPjdrqFxtH1dDPm5N7gO33oixfr/ueiMG8M03wSTfI21X28fV0I9b15YZuEkaMfrerKEytX1cDf24NbkP3CSNGH1v1lCZ2j6uhn7cmtwHbpJGjL43a6hMbR9XQz9uTe4DN0kjRt+bNVSmto+roR+3XlAduPULS1vdUTDJ90jb1fZxNfTj9rJNTPNgE5Mkbd9WTUyWZSSpQpZlNFNDbhopSZN5co7LYHLXzAy9aaQUTebJOS6HZRnNzNCbRkrRZJ6c43KY3DUzQ28aKUWTeXKOy2FZRjNz3eICq2Pe5ENpGumLy9XEm8yTczw7bV+78MxdMzP0ppE+WK+Jr547T3KhJn705Oo3v6fJPDnHszHJPDVlctfM9P2TaYZgkpp4k3lyjmdjHtcuLMtopg7sW/KN3qFJa+JN5sk5bm4e1y48c5cqMvTFskoxj3kyuUsVsSZehnnMk2UZqSJDXyyrFPOYJxcOk6RCuXCYJA2MZRm1rouFpmpc3KrGMak9Jne1qouFpmpc3KrGMaldlmXUqi4Wmqpxcasax6R2mdzVqi4Wmqpxcasax6R2mdzVqi6aamps5KlxTGqXyV2t6qKppsZGnhrHpHZ5QVWt6qKppsZGnhrHpHbZxCRJhbKJSZIGxrJMh2xKuaDP+6Lt2Po89hqN299QX8nLskxHLm1KgbULZEP84IM+74u2Y+vz2Gs0bn9f9byAgGeevZALS5kDyzI9ZFPKBX3eF23H1uex12jc/n7m//KixA51zIHJvSM2pVzQ533Rdmx9HnuNtrNfS5+DRsk9It4bEQ9FxJci4n2jbddExAMR8ejo8eqZRFoZm1Iu6PO+aDu2Po+9iaMnV9l/5Dh7Dt/H/iPHZ/rBz01sZ7+WPgdTJ/eIuAH4GeC1wKuAt0fEXuAwcCwz9wLHRs91CZtSLujzvmg7tj6PfVrrde3Vc+dJLixy1ocEP25/X/W84Kor4qJtpc8BNLtb5hXA32fm1wAi4q+BHwNuBd44+p67gM8Av9zg71TJppQL+rwv2o6tz2Of1lbXEboe12b7e9y2rmNtauq7ZSLiFcA9wOuB86ydpa8AP5mZixu+7+nMfE5pJiIOAYcAdu3a9ZpTp05NFYekftlz+D7GZZUAHjvyI/MOp2qt3C2TmQ8DvwE8AHwa+ALwjW38/J2ZuZyZyzt27Jg2DEk9U+t1hNI0amLKzI8CHwWIiF8HTgNPRcTOzHwyInYCZ5qHqXmrsbHGMc3H+99y/dh790uvYZemUXKPiBdn5pmI2AXcxlqJZg9wEDgyeryncZSaqxo/9ccxzU+N1xFK1KhDNSL+FngR8AzwC5l5LCJeBHwC2AV8GXhHZn51q98zxA7VPtt/5DirY+7xXVpc4O8O39xBRM05JtVoq5p707LMD4zZ9l/ALU1+r7pVY2ONY9LQuHBYxaatx163uDD2jHDeF8RmWU+e9Zgmja3Nmnhf5kkX68t1EJcfqFSTRpI+NNbMuhFmlmOaNLa2m3n6ME+6WJ8auEzulWqyINWBfUvccduNLC0uEKzVcOe9Qt6sF9Sa5Zgmja3tRcH6ME+6WJ8WgrMsU6mm9dgD+5Y6TRJt1JNnNaZJY5tHTbzredLF+nQdxDP3SpXeSNLn+CeNrc9jUDv6NOcm90qVXo/tc/yTxtbnMagdfZpzyzKVKr2RpM/xTxpbn8egdvRpzv2YPUkqlB+zJ0kDY3KXpAqZ3CWpQiZ3SaqQyV2SKmRyl6QKmdwlqUImd0mqkMldkipkcpekCrm2TEN9+dQVSdrI5N5AXz99XpIsyzTQp09dkaSNTO4N9OlTVyRpo0GWZWZVJ/fT5yX11eDO3Gf56eR9+tQVSdpocMl9lnVyP31eUl8Nriwz6zq5nz4vqY8Gd+bep08nl6S2DC65WyeXNASDK8v06dPJJaktg0vuYJ1cUv0GV5aRpCEY5Jn7OC4AJqkmJndcAExSfSzL4AJgkupjcscFwCTVp1Fyj4ifj4gvRcRDEfHxiPjWiLgmIh6IiEdHj1fPKti22NgkqTZTJ/eIWAJ+DljOzBuAK4B3AYeBY5m5Fzg2et5rNjZJqk3TssyVwEJEXAm8AHgCuBW4a/T6XcCBhn+jdS4AJqk2kZnT/3DEe4EPAueBv8zMH4+Ic5m5uOF7ns7M55RmIuIQcAhg165drzl16tTUcUjSEEXEicxcHvfa1LdCjmrptwJ7gHPAn0bET0z685l5J3AnwPLy8vT/hZkz74eXVIIm97n/IPBYZp4FiIi7ge8DnoqInZn5ZETsBM7MIM5e8H54SaVoUnP/MvC6iHhBRARwC/AwcC9wcPQ9B4F7moXYH94PL6kUU5+5Z+ZnI+KTwOeBbwAnWSuzfBvwiYh4D2v/AXjHLALtA++Hl1SKRssPZOYHgA9csvnrrJ3FV8cPxJZUCjtUt8H74SWVwoXDtsEP+pBUCpP7NvlBH5JKYFlGkipkcpekCpncJalCJndJqpDJXZIqZHKXpAp5K+ScuJqkpHkyuc+Bq0lKmjfLMnPgapKS5s3kPgeuJilp3kzuc7DZqpGuJimpLSb3OXA1SUnz5gXVOXA1SUnzZnKfE1eTlDRPRSd37x2XpPGKTe7eOy5Jmyv2gqr3jkvS5opN7t47LkmbKza5e++4JG2u2OTuveOStLliL6h677gkba7Y5A7eOy5Jmym2LCNJ2lzRZ+66YCgNXX0e5ySx9Tl+taeLeTe5V2AoDV19HucksfU5frWnq3m3LFOBoTR09Xmck8TW5/jVnq7m3eRegaE0dPV5nJPE1uf41Z6u5t3kXoGhNHT1eZyTxNbn+NWerubd5F6BEhu6jp5cZf+R4+w5fB/7jxzn6MnVy/5Mn8c5SWx9jr9E0xxDXehq3r2gWoHSGrqmvcDU53FOEluf4y9NSRenu5r3yMxW/8AklpeXc2VlpeswNCf7jxxndUy9cWlxgb87fHMHEak0HkNrIuJEZi6Pe82yjObOC4tqymPo8qZO7hFxfUQ8uOHff0fE+yLimoh4ICIeHT1ePcuAVT4vLKopj6HLmzq5Z+YjmXlTZt4EvAb4GvDnwGHgWGbuBY6Nnkvf5IVFNeUxdHmzuqB6C/BvmXkqIm4F3jjafhfwGeCXZ/R3VAEvLKopj6HLm8kF1Yj4GPD5zPztiDiXmYsbXns6M59TmomIQ8AhgF27dr3m1KlTjeOQpCHZ6oJq4zP3iHg+8KPA7dv5ucy8E7gT1u6WaRqH5s9FsOrifNZlFmWZH2btrP2p0fOnImJnZj4ZETuBMzP4G+qZku4z1uU5n/WZxa2Q7wY+vuH5vcDB0dcHgXtm8DfUMy6CVRfnsz6NkntEvAB4M3D3hs1HgDdHxKOj1440+RvqJ+8zrovzWZ9GZZnM/Brwoku2/Rdrd8+oYtctLoztEPQ+4zI5n/WxQ3VgZrXYkvcZ12Wz+XzTy3cUsTiXnsuFwwZklhfNvM+4LuPm800v38GfnVj1ImuhXDhsQFxsSdvh8dJ/LhwmwItm2h6Pl7JVX5YprTGjzXi9aHax0o6NaU07ztKOl1nPZ+nHR9Vn7us15tVz50ku1Az7elGo7Xi9CHpBacfGtJqMs6TjZdbzWcPxUXVyL60xo+14D+xb4o7bbmRpcYFgrXZ6x203FnU2MiulHRvTajLOko6XWc9nDcdH1WWZ0mqG84j3wL6lXr455620Y2NaTcdZyvEy6/ms4fio+sy9tAX9S4u3ZEPZ145zunHWsN+qTu4l1QyhvHhLNpR97TinG2cN+63qskxpjTalxVuyoexrxzndOGvYbzYxSVKhbGKSpIGpuiwjlar0Bhp1z+Qu9YyfiqRZsCwj9UwNDTTqnsld6pkaGmjUPZO71DM1NNCoeyZ3qWdqaKBR97ygKvVMDQ006p7JXeqhUhbsUn9ZlpGkCpncJalCJndJqpDJXZIqZHKXpApVdbeMiy1J5fN9PBvVJHcXW5LK5/t4dqopy7jYklQ+38ezU01yd7ElqXy+j2enmuTuYktS+Xwfz041yd3FlqTy+T6enWouqLrYklQ+38ezE5nZdQwsLy/nyspK12FIUlEi4kRmLo97rZqyjCTpgkZlmYhYBD4C3AAk8NPAI8CfALuBx4F3ZubTTf5OadpuwrDJoz3uW22lpOOj6Zn7h4FPZ+bLgVcBDwOHgWOZuRc4Nno+GOtNGKvnzpNcaMI4enK1iN8/ZO5bbaW042Pq5B4R3wG8AfgoQGb+b2aeA24F7hp9213AgWYhlqXtJgybPNrjvtVWSjs+mpy5fxdwFvj9iDgZER+JiBcCL8nMJwFGjy8e98MRcSgiViJi5ezZsw3C6Je2mzBs8miP+1ZbKe34aJLcrwReDfxeZu4D/odtlGAy887MXM7M5R07djQIo1/absKwyeO5jp5cZf+R4+w5fB/7jxyf+n+T3bftmtU8daW046NJcj8NnM7Mz46ef5K1ZP9UROwEGD2eaRZiWdpuwrDJ42KzrIO6b9tTWr16nNKOj6mTe2b+B/CViFgf2S3APwH3AgdH2w4C9zSKsDAH9i1xx203srS4QABLiwvccduNM7ui3vbvL80s66Du2/aUVq8ep7Tjo1ETU0TcxNqtkM8H/h34Kdb+g/EJYBfwZeAdmfnVrX6PTUya1p7D9zHuCA7gsSM/Mu9wtAnnqR1bNTE1us89Mx8Exv3iW5r8XmlS1y0usDrmglZf66BD5TzNnx2qKlppddChcp7mr5qFwzRMLjRVBudp/lw4TJIK5cJhkjQwlmUkNVbSglpDYXKX1Mh6g9L6fezrDUqACb5DlmUkNVJDg1KNTO6SGiltQa2hMLlLaqS0BbWGwuQuqREblPrJC6qSGrFBqZ9M7pIaO7BvyWTeM5ZlJKlCnrlvwcYMSaUyuW/CxgxJJbMsswkbMySVzOS+CRszJJXM5L4JGzMklczkvgkbMySVzAuqm7AxQ1LJTO5bsDFDUqksy0hShUzuklQhk7skVcjkLkkVMrlLUoUiM7uOgYg4C5xq8CuuBf5zRuF0wfi7VXr8UP4YjH8635mZO8a90Ivk3lRErGTmctdxTMv4u1V6/FD+GIx/9izLSFKFTO6SVKFakvudXQfQkPF3q/T4ofwxGP+MVVFzlyRdrJYzd0nSBiZ3SapQ0ck9It4aEY9ExL9GxOGu45lERHwsIs5ExEMbtl0TEQ9ExKOjx6u7jHErEfGyiPiriHg4Ir4UEe8dbS9iDBHxrRHxuYj4wij+XxttLyL+dRFxRUScjIhPjZ6XFv/jEfHFiHgwIlZG24oZQ0QsRsQnI+KfR++F1/ct/mKTe0RcAfwO8MPAK4F3R8Qru41qIn8AvPWSbYeBY5m5Fzg2et5X3wB+MTNfAbwO+NnRfi9lDF8Hbs7MVwE3AW+NiNdRTvzr3gs8vOF5afEDvCkzb9pwf3hJY/gw8OnMfDnwKtbmol/xZ2aR/4DXA/dveH47cHvXcU0Y+27goQ3PHwF2jr7eCTzSdYzbGMs9wJtLHAPwAuDzwPeWFD/wUtaSx83Ap0o8hoDHgWsv2VbEGIDvAB5jdENKX+Mv9swdWAK+suH56dG2Er0kM58EGD2+uON4JhIRu4F9wGcpaAyjksaDwBnggcwsKn7gt4BfAv5vw7aS4gdI4C8j4kREHBptK2UM3wWcBX5/VBr7SES8kJ7FX3JyjzHbvK9zTiLi24A/A96Xmf/ddTzbkZnPZuZNrJ0BvzYibug4pIlFxNuBM5l5outYGtqfma9mraz6sxHxhq4D2oYrgVcDv5eZ+4D/oesSzBglJ/fTwMs2PH8p8ERHsTT1VETsBBg9nuk4ni1FxFWsJfY/ysy7R5uLGgNAZp4DPsPaNZBS4t8P/GhEPA78MXBzRPwh5cQPQGY+MXo8A/w58FrKGcNp4PTo//gAPslasu9V/CUn938A9kbEnoh4PvAu4N6OY5rWvcDB0dcHWatj91JEBPBR4OHM/M0NLxUxhojYERGLo68XgB8E/plC4s/M2zPzpZm5m7Vj/nhm/gSFxA8QES+MiG9f/xr4IeAhChlDZv4H8JWIuH606Rbgn+hb/F1fnGh4YeNtwL8A/wb8StfxTBjzx4EngWdYOwN4D/Ai1i6QPTp6vKbrOLeI//tZK3/9I/Dg6N/bShkD8D3AyVH8DwG/OtpeRPyXjOWNXLigWkz8rNWsvzD696X1925hY7gJWBkdR0eBq/sWv8sPSFKFSi7LSJI2YXKXpAqZ3CWpQiZ3SaqQyV2SKmRyl6QKmdwlqUL/D/uDGO0vwBf8AAAAAElFTkSuQmCC", |
| 153 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAATUElEQVR4nO3df6zd9X3f8ecrlxtx05WajksHGOZMbby2SQrtLWVBXTMnrVlwAWXtlEhUqJ1qKWpTwhbTeKnaMU0KwtFKpFabEGFDJUpHM9eN0JiLmrkTf2B6HUMII26qlhLspL4sczc2F4x57497TK7ta9/z857zOXk+pKt7zufcz/2+zrm+L3/9/X6Ov6kqJEntedO4A0iS+mOBS1KjLHBJapQFLkmNssAlqVEXrOfGLrnkktq0adN6blKSmnfgwIGXqmr+zPF1LfBNmzaxuLi4npuUpOYl+cvVxj2EIkmNssAlqVEWuCQ1ygKXpEZZ4JLUqHVdhdKPPQcPs2vvIY4cO87lG+bYsXUzt1xzhXMHmLvn4GH+1eef5djxEwBc/JZZfuOnf3DkGXuZ92t7nuGz+7/GySpmEj74Y1fyb255x0jz9TJ3XK+hGdvLOEpZz/+NcGFhoXpZRrjn4GF27n6G4ydOvjE2NzvDJ97/jjVfPOeuPnfPwcPs+L2nOfH66T/32Zmw62d+aGQZe5n3a3ue4aEnXjjre9x63VVrlvg0v4ZmbC/jsCQ5UFULZ45P9CGUXXsPnfYCAxw/cZJdew85t8+5u/YeOusPI8CJkzXSjL3M++z+r636Pc41Pox8vcwd12toxvYyjtpEF/iRY8d7Gnfu2uPn+16jzNjLvJPn+Ffhucb73U6/c8f1GvYy14zDmTtoxlGb6AK/fMNcT+POXXv8fN9rlBl7mTeTrPq15xrvdzv9zh3Xa9jLXDMOZ+6gGUdtogt8x9bNzM3OnDY2NzvDjq2bndvn3B1bNzP7prOLcHYmI83Yy7wP/tiVq36Pc40PI18vc8f1GpqxvYyjNtGrUE6dIOjnLLNzV5976n6/Z9X7zdjLvFMnKvtZhTLNr6EZ28s4ahO9CkWS1OgqFEnSuVngktQoC1ySGmWBS1Kjui7wJDNJDiZ5pHP/6iRPJHkqyWKSa0cXU5J0pl72wG8Hnltx/x7grqq6Gvj1zn1J0jrpqsCTbARuBO5fMVzARZ3b3wUcGW40SdL5dPtGnnuBO4HvXDH2EWBvkk+y/BfBu4aaTJJ0XmvugSfZBhytqgNnPPQh4I6quhK4A/j0OeZv7xwjX1xaWho4sCRp2ZrvxEzyCeDngNeAC1k+bLIb+GlgQ1VVkgB/XVUXnfs7+U5MSepH3+/ErKqdVbWxqjYBHwC+UFW3snzM+yc6X7YF+OoQ80qS1jDIf2b1i8CnklwA/A2wfTiRJEnd6KnAq2ofsK9z+3HgR4YfSZLUDd+JKUmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqVNcFnmQmycEkj6wY+3CSQ0meTXLPaCJKklbTyxV5bgeeY/mixiT5R8DNwDur6pUkl44gnyTpHLraA0+yEbgRuH/F8IeAu6vqFYCqOjr8eJKkc+n2EMq9wJ3A6yvG3gb8eJL9Sf44yY+uNjHJ9iSLSRaXlpYGSytJesOaBZ5kG3C0qg6c8dAFwMXAdcAO4OEkOXN+Vd1XVQtVtTA/Pz+MzJIkujsGfj1wU5L3ARcCFyV5CHgR2F1VBTyZ5HXgEsDdbElaB2vugVfVzqraWFWbgA8AX6iqW4E9wBaAJG8D3gy8NLqokqSVelmFcqYHgAeSfBl4FbitszcuSVoHPRV4Ve0D9nVuvwrcOvxIkqRu+E5MSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1Kjui7wJDNJDiZ55IzxjyapJJcMP54k6Vx62QO/HXhu5UCSK4GfBF4YZihJ0tq6KvAkG4EbgfvPeOg3gTsBr4UpSeus2z3we1ku6tdPDSS5CThcVU+fb2KS7UkWkywuLS31HVSSdLo1CzzJNuBoVR1YMfYW4OPAr681v6ruq6qFqlqYn58fKKwk6Vu6uSr99cBNSd4HXAhcBPwO8Fbg6SQAG4EvJrm2qr4xqrCSpG9Zs8CraiewEyDJu4GPVtU/Wfk1SZ4HFqrqpeFHlCStxnXgktSobg6hvKGq9gH7VhnfNJw4kqRuuQcuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWpU1wWeZCbJwSSPdO7vSvKVJF9K8vtJNowspSTpLL3sgd8OPLfi/mPA26vqncCf0rlupiRpfXRV4Ek2AjcC958aq6o/rKrXOnefYPnK9JKkddLtHvi9wJ3A6+d4/BeAR1d7IMn2JItJFpeWlnpPKEla1ZoFnmQbcLSqDpzj8Y8DrwGfWe3xqrqvqhaqamF+fn6gsJKkb+nmqvTXAzcleR9wIXBRkoeq6tYktwHbgPdUVY0yqCTpdGvugVfVzqraWFWbgA8AX+iU9w3ArwI3VdX/G3FOSdIZBlkH/lvAdwKPJXkqyb8fUiZJUhe6OYTyhqraB+zr3P7eEeSRJHXJd2JKUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUqK7/P/AkM8AicLiqtiX5buA/AZuA54F/WlX/axQh+7Xn4GF27T3EkWPHuXzDHDu2buaWa64Y+dxBtJC53+1Mej4zmnHSMq4l3V7KMsk/BxaAizoFfg/wzaq6O8nHgIur6lfP9z0WFhZqcXFx4NDd2HPwMDt3P8PxEyffGJubneET73/Hmi/eIHOnPXO/25n0fGY046RlXCnJgapaOHO8q0MoSTYCNwL3rxi+GXiwc/tB4Jau06yDXXsPnfaiARw/cZJdew+NdO4gWsjc73YmPZ8ZzThpGbvR7THwe4E7gddXjH1PVX0doPP50tUmJtmeZDHJ4tLS0iBZe3Lk2PGexoc1dxAtZO53O5Oeb9C5vTDjcEx7xm6sWeBJtgFHq+pAPxuoqvuqaqGqFubn5/v5Fn25fMNcT+PDmjuIFjL3u51Jzzfo3F6YcTimPWM3utkDvx64KcnzwO8CW5I8BPxVkssAOp+PDiXRkOzYupm52ZnTxuZmZ9ixdfNI5w6ihcz9bmfS85nRjJOWsRtrrkKpqp3AToAk7wY+WlW3JtkF3Abc3fn8B0NJNCSnThD0c/Z3kLnTnrnf7Ux6PjOacdIydqPrVShwWoFvS/K3gYeBq4AXgJ+tqm+eb/56rkKRpGlxrlUoXa8DB6iqfcC+zu3/CbxnGOEkSb3znZiS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZ1c1HjC5M8meTpJM8muaszfnWSJ5I81bnq/LWjjytJOqWbK/K8AmypqpeTzAKPJ3kU+NfAXVX1aJL3AfcA7x5dVEnSSt1c1LiAlzt3Zzsf1fm4qDP+XcCRUQSUJK2uq2tiJpkBDgDfC/x2Ve1P8hFgb5JPsnwo5l3nmLsd2A5w1VVXDSOzJIkuT2JW1cmquhrYCFyb5O3Ah4A7qupK4A7g0+eYe19VLVTVwvz8/JBiS5J6WoVSVcdYvir9DcBtwO7OQ78HeBJTktZRN6tQ5pNs6NyeA94LfIXlY94/0fmyLcBXR5RRkrSKbo6BXwY82DkO/ibg4ap6JMkx4FNJLgD+hs5xbknS+uhmFcqXgGtWGX8c+JFRhJIkrc13YkpSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGtXNJdUuTPJkkqeTPJvkrhWPfTjJoc74PaONKklaqZtLqr0CbKmql5PMAo8neRSYA24G3llVryS5dJRBJUmn6+aSagW83Lk72/ko4EPA3VX1Sufrjo4qpCTpbF0dA08yk+Qp4CjwWFXtB94G/HiS/Un+OMmPnmPu9iSLSRaXlpaGFlySvt11VeBVdbKqrgY2AtcmeTvLe+8XA9cBO4CHk2SVufdV1UJVLczPzw8vuSR9m+tpFUpVHQP2ATcALwK7a9mTwOvAJcMOKElaXTerUOaTbOjcngPeC3wF2ANs6Yy/DXgz8NKogkqSTtfNKpTLgAeTzLBc+A9X1SNJ3gw8kOTLwKvAbZ0TnpKkddDNKpQvAdesMv4qcOsoQkmS1uY7MSWpURa4JDXKApekRnVzErNZew4eZtfeQxw5dpzLN8yxY+tmbrnmipHPHUQLmfvdzqTnM6MZ1+N3fJiyngtHFhYWanFxcV22tefgYXbufobjJ06+MTY3O8Mn3v+ONX9Ig8yd9sz9bmfS85nRjOvxO96vJAeqauHM8ak9hLJr76HTfjgAx0+cZNfeQyOdO4gWMve7nUnPZ0Yzrsfv+LBNbYEfOXa8p/FhzR1EC5n73c6k5xt0bi/MOBwtZBy1qS3wyzfM9TQ+rLmDaCFzv9uZ9HyDzu2FGYejhYyjNrUFvmPrZuZmZ04bm5udYcfWzSOdO4gWMve7nUnPZ0Yzrsfv+LBN7SqUUyci+jnLPMjcac/c73YmPZ8ZzegqlDWs5yoUSZoW33arUCRp2lngktQoC1ySGmWBS1Kjurkiz4VJnkzydJJnk9x1xuMfTVJJvJyaJK2jbpYRvgJsqaqXk8wCjyd5tKqeSHIl8JPACyNNKUk6y5p74J2LFr/cuTvb+Ti19vA3gTtX3JckrZOujoEnmUnyFHAUeKyq9ie5CThcVU+PMqAkaXVdvROzqk4CV3euTv/7Sd4JfBz4qbXmJtkObAe46qqr+k8qSTpNT6tQquoYsA+4GXgr8HSS54GNwBeT/J1V5txXVQtVtTA/Pz9wYEnSsm5Wocx39rxJMge8FzhYVZdW1aaq2gS8CPxwVX1jlGElSd/SzSGUy4AHk8ywXPgPV9Ujo40lSVrLmgVeVV8CrlnjazYNK5AkqTu+E1OSGmWBS1KjLHBJapQFLkmNmtpLqo3TnoOHJ/5STZOecdLzgRmHxYz9s8CHbM/Bw+zc/QzHT5wE4PCx4+zc/QzARPzAYfIzTno+MOOwmHEwHkIZsl17D73xgz7l+ImT7Np7aEyJzjbpGSc9H5hxWMw4GAt8yI4cO97T+DhMesZJzwdmHBYzDsYCH7LLN8z1ND4Ok55x0vOBGYfFjIOxwIdsx9bNzM3OnDY2NzvDjq2bx5TobJOecdLzgRmHxYyD8STmkJ06qTGJZ6xPmfSMk54PzDgsZhxMqtbvYjoLCwu1uLi4btuTpGmQ5EBVLZw57iEUSWqUBS5JjbLAJalRFrgkNcoCl6RGresqlCRLwF/2Of0S4KUhxplE0/4cp/35wfQ/x2l/fjCZz/HvVtVZV4Vf1wIfRJLF1ZbRTJNpf47T/vxg+p/jtD8/aOs5eghFkhplgUtSo1oq8PvGHWAdTPtznPbnB9P/HKf9+UFDz7GZY+CSpNO1tAcuSVrBApekRjVR4EluSHIoyZ8l+di48wxTkiuT/LckzyV5Nsnt4840CklmkhxM8si4s4xCkg1JPpfkK52f5T8Yd6ZhS3JH58/ol5N8NsmF4840iCQPJDma5Msrxr47yWNJvtr5fPE4M65l4gs8yQzw28A/Bn4A+GCSHxhvqqF6DfgXVfX9wHXAL03Z8zvlduC5cYcYoU8B/7Wq/j7wQ0zZc01yBfArwEJVvR2YAT4w3lQD+4/ADWeMfQz4o6r6PuCPOvcn1sQXOHAt8GdV9edV9Srwu8DNY840NFX19ar6Yuf2/2H5F3/8/1P8ECXZCNwI3D/uLKOQ5CLgHwKfBqiqV6vq2FhDjcYFwFySC4C3AEfGnGcgVfXfgW+eMXwz8GDn9oPALeuZqVctFPgVwNdW3H+RKSu4U5JsAq4B9o85yrDdC9wJvD7mHKPy94Al4D90DhPdn+Q7xh1qmKrqMPBJ4AXg68BfV9UfjjfVSHxPVX0dlneugEvHnOe8WijwrDI2dWsfk/wt4D8DH6mq/z3uPMOSZBtwtKoOjDvLCF0A/DDw76rqGuD/MuH/9O5V51jwzcBbgcuB70hy63hTqYUCfxG4csX9jTT+T7czJZllubw/U1W7x51nyK4HbkryPMuHv7YkeWi8kYbuReDFqjr1L6fPsVzo0+S9wF9U1VJVnQB2A+8ac6ZR+KsklwF0Ph8dc57zaqHA/wT4viRvTfJmlk+cfH7MmYYmSVg+dvpcVf3bcecZtqraWVUbq2oTyz+7L1TVVO25VdU3gK8lOXWZ8vcA/2OMkUbhBeC6JG/p/Jl9D1N2orbj88Btndu3AX8wxixrmvir0lfVa0l+GdjL8pnvB6rq2THHGqbrgZ8DnknyVGfsX1bVfxlfJPXhw8BnOjsZfw78/JjzDFVV7U/yOeCLLK+cOkhDbzlfTZLPAu8GLknyIvAbwN3Aw0n+Gct/af3s+BKuzbfSS1KjWjiEIklahQUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGvX/Acmj+IjC2rUHAAAAAElFTkSuQmCC", |
127 | 154 | "text/plain": [ |
128 | 155 | "<Figure size 432x288 with 1 Axes>" |
129 | 156 | ] |
|
135 | 162 | } |
136 | 163 | ], |
137 | 164 | "source": [ |
138 | | - "plt.scatter(range(batch.shape[1]), batch[7])" |
| 165 | + "i = 0\n", |
| 166 | + "plt.scatter(batch['time'][i].cumsum(0), batch['pitch'][i])" |
139 | 167 | ] |
140 | 168 | }, |
141 | 169 | { |
142 | 170 | "cell_type": "code", |
143 | | - "execution_count": 8, |
| 171 | + "execution_count": 14, |
144 | 172 | "id": "ea23aea6-1629-4d1c-b1be-0865f8a07e9d", |
145 | 173 | "metadata": {}, |
146 | 174 | "outputs": [], |
|
239 | 267 | "for batch in tqdm(it.islice(dl,512)):\n", |
240 | 268 | " # batch = torch.LongTensor([notes for notes in it.islice(gen_tracks(batch_len), batch_size)])\n", |
241 | 269 | " opt.zero_grad()\n", |
242 | | - " r = net(batch)\n", |
| 270 | + " r = net(batch['pitch'])\n", |
243 | 271 | " nll = (-r['log_probs']).mean()\n", |
244 | 272 | " nll.backward()\n", |
245 | 273 | " opt.step()\n", |
|
407 | 435 | ], |
408 | 436 | "source": [ |
409 | 437 | "counts = torch.zeros(130,130).long()\n", |
410 | | - "for s in tqdm(ds):\n", |
| 438 | + "for item in tqdm(ds):\n", |
| 439 | + " s = item['pitch']\n", |
411 | 440 | " counts[s[:-1], s[1:]] += 1" |
412 | 441 | ] |
413 | 442 | }, |
|
564 | 593 | ], |
565 | 594 | "source": [ |
566 | 595 | "counts = torch.zeros(130,130,130).long()\n", |
567 | | - "for s in tqdm(ds):\n", |
| 596 | + "for item in tqdm(ds):\n", |
| 597 | + " s = item['pitch']\n", |
568 | 598 | " counts[s[:-2], s[1:-1], s[2:]] += 1" |
569 | 599 | ] |
570 | 600 | }, |
|
0 commit comments