# Recurrent Neural Networks {#sec-rnn}
So far, we have limited our attention to domains in which each output \(y\) is assumed to have been generated as a function of an associated input \(x\), and our hypotheses have been “pure” functions, in which the output depends only on the input (and the parameters we have learned that govern the function’s behavior). In the next few chapters, we are going to consider cases in which our models need to go beyond functions. In particular, behavior as a function of *time* will be an important concept:
- In *recurrent neural networks*, the hypothesis that we learn is not a function of a single input, but of the whole sequence of inputs that the predictor has received.
- In *reinforcement learning*, the hypothesis is either a *model* of a domain (such as a game) as a recurrent system, or a *policy*, which is a pure function but whose loss is determined by the ways in which the policy interacts with the domain over time.
In this chapter, we introduce *state machines*. We start with deterministic state machines, and then consider recurrent neural network (RNN) architectures to model their behavior. Later, in Chapter [MDPs](#chap-mdps), we will study *Markov decision processes* (MDPs), which extend state machines to probabilistic (rather than deterministic) transitions. RNNs and MDPs will enable description and modeling of temporally sequential patterns of behavior that are important in many domains.
## State machines {#sec-state_machines}
A *state machine*
::: {.column-margin}
This is such a pervasive idea that it has been given many names in many subareas of computer science, control theory, physics, etc., including: *automaton*, *transducer*, *dynamical system*, etc.
:::
is a description of a process (computational, physical, economic) in terms of its potential sequences of *states*.
The *state* of a system is defined to be all you would need to know about the system to predict its future trajectories as well as possible. It could be the position and velocity of an object, the locations of your pieces on a game board, or the current traffic densities on a highway network.
Formally, we define a *state machine* as
\[
(\mathcal{S}, \mathcal{X}, \mathcal{Y}, s_0, f_s, f_o)
\]
where
- \(\mathcal{S}\) is a finite or infinite set of possible states;
- \(\mathcal{X}\) is a finite or infinite set of possible inputs;
- \(\mathcal{Y}\) is a finite or infinite set of possible outputs;
- \(s_0 \in \mathcal{S}\) is the initial state of the machine;
- \(f_s: \mathcal{S} \times \mathcal{X} \rightarrow \mathcal{S}\) is a *transition function* (which takes a previous state and an input and produces a next state);
- \(f_o: \mathcal{S} \rightarrow \mathcal{Y}\) is an *output function* (which takes a state and produces an output).
The basic operation of the state machine is to
::: {.column-margin}
In some cases, we will pick a starting state from a set or distribution.
:::
start with state \(s_0\), then iteratively compute for \(t \ge 1\):
\[
s_t = f_s(s_{t-1}, x_t)
\]
\[
y_t = f_o(s_t)
\]
::: {.example}
The diagram below illustrates this process. Note that the “feedback” connection of \(s_t\) back into \(f_s\) has to be buffered or delayed by one time step—otherwise what it is computing would not generally be well defined.
*Diagram omitted or replaced with an appropriate image/tikz figure.*
:::
So, given a sequence of inputs \(x_1, x_2, \dots\) the machine generates a sequence of outputs
\[
\underbrace{f_o(f_s(s_0, x_1))}_{y_1}, \underbrace{f_o(f_s(f_s(s_0, x_1), x_2))}_{y_2}, \dots
\]
We sometimes say that the machine *transduces* sequence \(x\) into sequence \(y\). The output at time \(t\) can depend on inputs from steps 1 to \(t\).
::: {.column-margin}
There are a huge number of major and minor variations on the idea of a state machine. We'll just work with one specific one in this section and another one in the next, but don't worry if you see other variations out in the world!
:::
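The iteration \(s_t = f_s(s_{t-1}, x_t)\), \(y_t = f_o(s_t)\) can be sketched in a few lines of Python. This is only an illustration: the helper name `transduce` and the running-sum machine are assumptions chosen for the example, not something defined in these notes.

```python
def transduce(f_s, f_o, s0, xs):
    """Run a state machine on the input sequence xs and return the outputs."""
    s = s0
    ys = []
    for x in xs:
        s = f_s(s, x)      # s_t = f_s(s_{t-1}, x_t)
        ys.append(f_o(s))  # y_t = f_o(s_t)
    return ys


# Hypothetical example machine: the state is the running sum of the inputs,
# and the output function simply reports the state.
running_sum = transduce(lambda s, x: s + x, lambda s: s, 0, [1, 2, 3, 4])
print(running_sum)  # [1, 3, 6, 10]
```

Note that the output at step \(t\) depends on all inputs up to \(t\), even though each call to `f_s` sees only the current input and the previous state.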
One common form is *finite state machines*, in which \(\mathcal{S}\), \(\mathcal{X}\), and \(\mathcal{Y}\) are all finite sets. They are often described using *state transition diagrams* such as the one below, in which nodes stand for states and arcs indicate transitions. Nodes are labeled by which output they generate and arcs are labeled by which input causes the transition.
::: {.example}
One can verify that the state machine below reads binary strings and determines the parity of the number of zeros in the given string. Check for yourself that an input binary string ends in state \(S_1\) if and only if it contains an even number of zeros.
*Figure omitted: state transition diagram of the parity machine.*
:::
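A parity machine of this kind can be written as a lookup table. The sketch below is an assumption-laden reconstruction: the name `S0` for the "odd" state, the output labels, and the choice of \(S_1\) as the start state (the empty string has zero zeros, which is even) are all hypothetical choices for this example.

```python
# Transition table: reading a "0" flips the state, reading a "1" keeps it.
# "S1" means an even number of zeros so far; "S0" (hypothetical name) means odd.
TRANSITIONS = {
    ("S1", "0"): "S0", ("S1", "1"): "S1",
    ("S0", "0"): "S1", ("S0", "1"): "S0",
}
OUTPUTS = {"S1": "even", "S0": "odd"}  # assumed output labels


def final_state(bits, start="S1"):
    """Return the state reached after reading the binary string bits."""
    state = start
    for b in bits:
        state = TRANSITIONS[(state, b)]
    return state


# Check the "if and only if" claim on a few strings.
for s in ["", "0", "1101", "1001", "000"]:
    assert (final_state(s) == "S1") == (s.count("0") % 2 == 0)

print(final_state("1001"), OUTPUTS[final_state("1001")])  # S1 even
```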
Another common structure that is simple but powerful and used in signal processing and control is *linear time-invariant (LTI) systems*. In this case, all the quantities are real-valued vectors: \(\mathcal{S} = \mathbb{R}^m\), \(\mathcal{X} = \mathbb{R}^\ell\) and \(\mathcal{Y} = \mathbb{R}^n\). The functions \(f_s\) and \(f_o\) are linear functions of their inputs. The transition function is described by the state matrix \(A\) and the input matrix \(B\); the output function is defined by the output matrix \(C\), each with compatible dimensions. In discrete time, they can be defined by a linear difference equation, like
\[
s_t = f_s(s_{t-1}, x_t) = A s_{t-1} + B x_t,
\]
\[
y_t = f_o(s_t) = C s_t,
\]
and can be implemented using state to store relevant previous input and output information.
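As a small illustration of how an LTI system can use its state to remember previous inputs, here is a sketch in plain Python. The matrices are hypothetical values chosen for this example (with \(m = 2\), \(\ell = 1\), \(n = 1\)): the state holds the two most recent inputs, and the output is their average.

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def lti_run(A, B, C, s0, xs):
    """Simulate s_t = A s_{t-1} + B x_t and y_t = C s_t over the inputs xs."""
    s = s0
    ys = []
    for x in xs:
        s = [a + b for a, b in zip(matvec(A, s), matvec(B, x))]
        ys.append(matvec(C, s))
    return ys


# Hypothetical two-step moving average: s_t = (x_t, x_{t-1}).
A = [[0.0, 0.0],
     [1.0, 0.0]]   # shifts the previous input into the second state slot
B = [[1.0],
     [0.0]]        # writes the new input into the first state slot
C = [[0.5, 0.5]]   # averages the two stored inputs

ys = lti_run(A, B, C, [0.0, 0.0], [[2.0], [4.0], [6.0]])
print(ys)  # [[1.0], [3.0], [5.0]]
```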
We will study *recurrent neural networks* which are a lot like a non-linear version of an LTI system.
## Recurrent neural networks {#sec-rnn_model}
In Chapter [Neural Networks](#chap-neural_networks), we studied how a neural network’s weights can be obtained by training on data so that the network models a function approximating the relationship between \((x, y)\) pairs in a supervised-learning training set. In Section [State machines](#sec-state_machines) above, we introduced state machines to describe sequential temporal behavior. Here, we explore recurrent neural networks by defining the architecture and weight matrices in a neural network to enable modeling of such state machines. Then, in Section [Sequence-to-sequence RNN](#sec-seq2seq_rnn), we present a loss function that may be employed for training sequence-to-sequence RNNs, and in Section [RNN as a language model](#sec-language) we consider applications to language translation and recognition. Finally, in Section [Back-propagation through time](#sec-bptt), we see how to use gradient-descent methods to train the weights of an RNN so that it performs a *transduction* matching as closely as possible a training set of input–output *sequences*.
A *recurrent neural network* is a state machine in which neural networks constitute the functions \(f_s\) and \(f_o\):
\[
s_t = f_s\Bigl(W^{sx}x_t + W^{ss}s_{t-1} + W^{ss}_0\Bigr)
\]
\[
y_t = f_o\Bigl(W^o s_t + W^o_0\Bigr).
\]
The inputs, states, and outputs are all vector-valued:
\[
x_t: \ell \times 1,\quad s_t: m \times 1,\quad y_t: v \times 1.
\]
The weights in the network are
\[
W^{sx}: m \times \ell,\quad W^{ss}: m \times m,\quad W^{ss}_0: m \times 1,
\]
\[
W^o: v \times m,\quad W^o_0: v \times 1,
\]
with activation functions \(f_s\) and \(f_o\).
::: {.study_question_callout}
Check dimensions here to be sure it all works out. Remember that we apply \(f_s\) and \(f_o\) elementwise, unless \(f_o\) is a softmax activation.
:::
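One way to check the dimensions is to run a single step by hand. The sketch below uses plain Python lists; the weight values, the sizes \(\ell = 3\), \(m = 2\), \(v = 1\), the choice \(f_s = \tanh\), and the identity output activation are all assumptions made for this example.

```python
import math


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def rnn_step(W_sx, W_ss, W_ss0, W_o, W_o0, s_prev, x):
    """One RNN step with f_s = tanh and f_o = identity (both assumptions)."""
    z1 = [a + b + c
          for a, b, c in zip(matvec(W_sx, x), matvec(W_ss, s_prev), W_ss0)]
    s = [math.tanh(z) for z in z1]                      # s_t, length m
    y = [a + b for a, b in zip(matvec(W_o, s), W_o0)]   # y_t, length v
    return s, y


# Hypothetical sizes: l = 3 inputs, m = 2 state dimensions, v = 1 output.
W_sx = [[0.1, 0.0, -0.2], [0.3, 0.2, 0.0]]   # m x l
W_ss = [[0.5, -0.1], [0.0, 0.4]]             # m x m
W_ss0 = [0.0, 0.1]                           # m x 1, stored as a flat list
W_o = [[1.0, -1.0]]                          # v x m
W_o0 = [0.0]                                 # v x 1, stored as a flat list

s = [0.0, 0.0]                               # s_0
outputs = []
for x in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]:
    s, y = rnn_step(W_sx, W_ss, W_ss0, W_o, W_o0, s, x)
    outputs.append(y)

print(len(s), len(outputs[0]))  # state has m entries, output has v entries
```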
## Sequence-to-sequence RNN {#sec-seq2seq_rnn}
Now, how can we set up an RNN to model and be trained to produce a transduction of one sequence to another? This problem is sometimes called *sequence-to-sequence mapping*. You can think of it as a kind of regression problem: given an input sequence, learn to generate the corresponding output sequence.
::: {.column-margin}
One way to think of training a sequence **classifier** is to reduce it to a transduction problem, where \(y_t = 1\) if the sequence \(x_1, \ldots, x_t\) is a *positive* example of the class of sequences and \(-1\) otherwise.
:::
A training set has the form
\[
\left[\left(x^{(1)}, y^{(1)}\right), \dots, \left(x^{(q)}, y^{(q)}\right)\right],
\]
where
- \(x^{(i)}\) and \(y^{(i)}\) are length \(n^{(i)}\) sequences;
- sequences in the *same pair* are the same length; and
- sequences in different pairs may have different lengths.
Next, we need a loss function. We start by defining a loss function on sequences. There are many possible choices, but usually it makes sense to sum a per-element loss function on each of the output values, where \(g\) is the predicted sequence and \(y\) is the actual one:
\[
\mathcal{L}_{\text{seq}}\left(g^{(i)}, y^{(i)}\right) = \sum_{t=1}^{n^{(i)}} \mathcal{L}_\text{elt}\left(g_t^{(i)}, y_t^{(i)}\right).
\]
The per-element loss function \(\mathcal{L}_\text{elt}\)
::: {.column-margin}
So it could be NLL, squared loss, etc.
:::
will depend on the type of \(y_t\) and what information it is encoding, just as in a supervised network.
Then, letting
\[
W = \Bigl(W^{sx}, W^{ss}, W^o, W^{ss}_0, W^o_0\Bigr),
\]
our overall goal is to minimize the objective
\[
J(W) = \frac{1}{q} \sum_{i=1}^{q} \mathcal{L}_{\text{seq}}\Bigl(\text{RNN}(x^{(i)};W), y^{(i)}\Bigr),
\]
where \(\text{RNN}(x;W)\) is the output sequence generated given input sequence \(x\).
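The objective can be sketched directly from its definition. In the example below, `predict` is a hypothetical stand-in for \(\text{RNN}(\cdot\,; W)\) (here just a running sum), outputs are scalars, and squared error is assumed for \(\mathcal{L}_\text{elt}\); none of these choices come from the notes.

```python
def seq_loss(g, y, elt_loss):
    """Sum the per-element loss over one predicted/actual sequence pair."""
    assert len(g) == len(y), "sequences in a pair must have the same length"
    return sum(elt_loss(g_t, y_t) for g_t, y_t in zip(g, y))


def objective(predict, data, elt_loss):
    """Average the sequence loss over the q training pairs."""
    return sum(seq_loss(predict(x), y, elt_loss) for x, y in data) / len(data)


def squared(g_t, y_t):
    """Assumed per-element loss: squared error on scalar outputs."""
    return (g_t - y_t) ** 2


def running_sum(x):
    """Hypothetical stand-in for RNN(x; W): outputs the running sum of x."""
    total, out = 0.0, []
    for x_t in x:
        total += x_t
        out.append(total)
    return out


# Two training pairs of different lengths (q = 2).
data = [([1.0, 2.0], [1.0, 3.0]),
        ([1.0, 1.0, 1.0], [1.0, 2.0, 4.0])]
print(objective(running_sum, data, squared))  # 0.5
```

The first pair is predicted exactly and the second is off by one in its last element, so the average sequence loss is \((0 + 1)/2\).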
It is typical to choose \(f_s\) to be *tanh*
::: {.column-margin}
Remember that it looks like a sigmoid but ranges from \(-1\) to \(+1\).
:::
but any non-linear activation function is usable. We choose \(f_o\) to align with the types of our outputs and the loss function, just as in regular supervised learning.
## RNN as a language model {#sec-language}
A *language model* is a sequence-to-sequence RNN which is trained on a token sequence \(c = (c_1, c_2, \ldots, c_k)\) and is used to predict the next token \(c_t\) (for \(t \le k\)) given the previous \(t-1\) tokens:
\[
c_t = \text{RNN}\Bigl((c_1, c_2, \dots, c_{t-1}); W\Bigr).
\]
We can convert this to a sequence-to-sequence training problem by constructing a dataset of \(q\) different \((x,y)\) sequence pairs, where we introduce special tokens \(\langle\text{start}\rangle\) and \(\langle\text{end}\rangle\) to signal the beginning and end of the sequence:
\[
x = \bigl(\langle\text{start}\rangle, c_1, c_2, \ldots, c_k\bigr),
\]
\[
y = \bigl(c_1, c_2, \ldots, c_k, \langle\text{end}\rangle\bigr).
\]
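Building such a pair from a token sequence is a one-line shift. The sketch below uses hypothetical string tokens and the helper name `make_pair`, neither of which appears in the notes.

```python
START, END = "<start>", "<end>"   # special tokens, spelled as plain strings


def make_pair(tokens):
    """Turn a token sequence c_1..c_k into an (x, y) training pair."""
    x = [START] + list(tokens)    # (<start>, c_1, ..., c_k)
    y = list(tokens) + [END]      # (c_1, ..., c_k, <end>)
    return x, y


x, y = make_pair(["the", "cat", "sat"])
print(x)  # ['<start>', 'the', 'cat', 'sat']
print(y)  # ['the', 'cat', 'sat', '<end>']
```

Both sequences have length \(k + 1\), and the target at each position is the token that follows the input at that position.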
## Back-propagation through time {#sec-bptt}
Now the fun begins! We can try to find a \(W\) to minimize \(J\) using gradient descent. We will work through the simplest method, *back-propagation through time* (BPTT), in detail. This method is generally not the best to use, but it is relatively easy to understand. In Section [LSTM](#lstm) we will sketch alternative methods that are more common in practice.
::: {.notice}
What we want you to take away from this section is that by “unrolling” a recurrent network to model a particular sequence, we can treat the whole system as a feed-forward network with extensive parameter sharing. Thus, we can tune the parameters using stochastic gradient descent and learn to model sequential mappings. While the details are important to get right if you need to implement something, the mathematical details below are presented primarily to convey the larger concepts.
:::
::: {.example}
**Calculus reminder: total derivative**
Most of us are not very careful about the difference between the *partial derivative* and the *total derivative*. Here’s an example from the Wikipedia article on partial derivatives.
The volume of a circular cone depends on its height and radius:
\[
V(r, h) = \frac{\pi r^2 h}{3}.
\]
The partial derivatives of the volume with respect to the radius and height are
\[
\frac{\partial V}{\partial r} = \frac{2\pi r h}{3}\quad\text{and}\quad \frac{\partial V}{\partial h} = \frac{\pi r^2}{3}.
\]
These measure the change in \(V\) assuming that all other variables are held constant. Now assume that we want to preserve the cone’s proportions so that the ratio of radius to height remains constant. Then we cannot change one without changing the other, and we must consider the *total derivative*. For example, the total derivative with respect to \(r\) is given by
\[
\frac{dV}{dr} = \frac{\partial V}{\partial r} + \frac{\partial V}{\partial h}\frac{dh}{dr} = \frac{2 \pi r h}{3} + \frac{\pi r^2}{3}\frac{dh}{dr}.
\]
Similarly, the total derivative with respect to \(h\) is
\[
\frac{dV}{dh} = \frac{\partial V}{\partial h} + \frac{\partial V}{\partial r}\frac{dr}{dh} = \frac{\pi r^2}{3} + \frac{2 \pi r h}{3}\frac{dr}{dh}.
\]
For a concrete example, consider a right circular cone with a fixed angle \(\alpha = \tan^{-1}(r/h)\), so that \(r = h\,\tan\alpha\). Let \(c = \tan\alpha\); then \(r = c\,h\), \(\frac{dh}{dr} = \frac{1}{c}\), \(\frac{dr}{dh} = c\), and we have
\[
\frac{dV}{dr} = \frac{2 \pi r h}{3} + \frac{\pi r^2}{3}\frac{1}{c},\quad
\frac{dV}{dh} = \frac{\pi r^2}{3} + \frac{2 \pi r h}{3}\,c.
\]
:::
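The difference between the two kinds of derivative can be checked numerically. The sketch below assumes a ratio \(c = r/h = 0.5\) and height \(h = 2\) (values picked for this example) and compares the formula for \(dV/dh\) against finite differences.

```python
import math


def volume(r, h):
    """Volume of a circular cone with radius r and height h."""
    return math.pi * r ** 2 * h / 3


c = 0.5       # hypothetical fixed ratio r / h
h = 2.0
r = c * h
eps = 1e-6

# Total derivative dV/dh from the formula, using dr/dh = c.
analytic_total = math.pi * r ** 2 / 3 + (2 * math.pi * r * h / 3) * c

# Finite difference in which r moves with h to preserve the ratio.
numeric_total = (volume(c * (h + eps), h + eps)
                 - volume(c * (h - eps), h - eps)) / (2 * eps)

# Finite difference with r held fixed: this is the partial derivative.
numeric_partial = (volume(r, h + eps) - volume(r, h - eps)) / (2 * eps)

print(analytic_total, numeric_total, numeric_partial)
```

With these values the total derivative is \(\pi\), while the partial derivative is only \(\pi/3\): the remaining \(2\pi/3\) comes from the radius changing along with the height.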
The **BPTT** process goes as follows:
1. **Sampling:** Sample a training pair of sequences \((x,y)\); let their length be \(n\).
2. **Unrolling:** “Unroll” the RNN to length \(n\) (for example, see the diagram below for \(n=3\)), and initialize \(s_0\):
*Figure omitted: the RNN unrolled for \(n=3\) time steps.*
Now the problem resembles performing back-propagation on a feed-forward network—except that the weight matrices are shared across time. (This is similar in spirit to convolutional networks, where weights are reused spatially.)
3. **Forward Pass:** Compute the predicted output sequence \(g\) via
\[
z_t^1 = W^{sx}x_t + W^{ss}s_{t-1} + W^{ss}_0,
\]
\[
s_t = f_s(z_t^1),
\]
\[
z_t^2 = W^o s_t + W^o_0,
\]
\[
g_t = f_o(z_t^2).
\]
4. **Backward Pass:** Compute the gradients. For both \(W^{ss}\) and \(W^{sx}\), we need
\[
\frac{d \mathcal{L}_\text{seq}(g,y)}{d W} = \sum_{u=1}^n \frac{d \mathcal{L}_\text{elt}(g_u,y_u)}{d W}.
\]
Let \(\mathcal{L}_u = \mathcal{L}_\text{elt}(g_u,y_u)\). Using the total derivative (summing over all the ways \(W\) affects \(\mathcal{L}_u\)), we have
\[
\frac{d \mathcal{L}_\text{seq}}{d W} = \sum_{t=1}^n \frac{\partial s_t}{\partial W} \left(\frac{\partial \mathcal{L}_t}{\partial s_t} + \underbrace{\sum_{u=t+1}^n \frac{\partial \mathcal{L}_u}{\partial s_t}}_{\delta^{s_t}}\right).
\]
Here, \(\delta^{s_t}\) represents the impact of state \(s_t\) on all future losses. Define the *future loss* after step \(t\) as
\[
F_t = \sum_{u=t+1}^{n} \mathcal{L}_\text{elt}(g_u,y_u)
\]
so that
\[
\delta^{s_t} = \frac{\partial F_t}{\partial s_t}.
\]
Note that \(F_n = 0\) (hence \(\delta^{s_n} = 0\)).
Working backwards, for each \(t\) we have
\[
\delta^{s_{t-1}} = \frac{\partial s_t}{\partial s_{t-1}} \left[\frac{\partial \mathcal{L}_\text{elt}(g_t,y_t)}{\partial s_t} + \delta^{s_t}\right].
\]
Using the chain rule, we write
\[
\frac{\partial \mathcal{L}_\text{elt}(g_t,y_t)}{\partial s_t} = \frac{\partial z_t^2}{\partial s_t}\,\frac{\partial \mathcal{L}_\text{elt}(g_t,y_t)}{\partial z_t^2},
\]
and
\[
\frac{\partial s_t}{\partial s_{t-1}} = \frac{\partial z_t^1}{\partial s_{t-1}}\,\frac{\partial s_t}{\partial z_t^1} = {W^{ss}}^T\,\frac{\partial s_t}{\partial z_t^1}.
\]
::: {.example}
Note that \(\frac{\partial s_t}{\partial z_t^1}\) is formally an \(m \times m\) diagonal matrix whose diagonal entries are \(f_s'(z_{t,i}^1)\) for \(1\le i\le m\). Often, we represent this diagonal matrix as an \(m \times 1\) vector \(f_s'(z_t^1)\). In that case, the product \({W^{ss}}^T * f_s'(z_t^1)\) should be interpreted as multiplying each column of \({W^{ss}}^T\) by the corresponding entry of \(f_s'(z_t^1)\).
:::
Putting everything together, we obtain
\[
\delta^{s_{t-1}} = {W^{ss}}^T\,\frac{\partial s_t}{\partial z_t^1}\,\Bigl({W^o}^T\,\frac{\partial \mathcal{L}_t}{\partial z_t^2} + \delta^{s_t}\Bigr).
\]
The gradients for the weight matrices are then given by
\[
\frac{d \mathcal{L}_\text{seq}}{d W^{ss}} = \sum_{t=1}^n \frac{\partial z_t^1}{\partial W^{ss}}\,\frac{\partial s_t}{\partial z_t^1}\,\frac{\partial F_{t-1}}{\partial s_t},
\]
\[
\frac{d \mathcal{L}_\text{seq}}{d W^{sx}} = \sum_{t=1}^n \frac{\partial z_t^1}{\partial W^{sx}}\,\frac{\partial s_t}{\partial z_t^1}\,\frac{\partial F_{t-1}}{\partial s_t}.
\]
The weight \(W^o\) is simpler because it does not affect future losses:
\[
\frac{d \mathcal{L}_\text{seq}}{d W^o} = \sum_{t=1}^n \frac{\partial \mathcal{L}_t}{\partial z_t^2}\,\frac{\partial z_t^2}{\partial W^o}.
\]
Assuming \(\frac{\partial \mathcal{L}_t}{\partial z_t^2} = (g_t - y_t)\) (which holds for softmax with NLL loss and, up to a constant factor, for squared loss with a linear output), then
\[
\frac{d \mathcal{L}_\text{seq}}{d W^o} = \sum_{t=1}^n (g_t - y_t)\, s_t^T.
\]
Whew!
::: {.study_question_callout}
Derive the updates for the offsets \(W^{ss}_0\) and \(W^o_0\).
:::
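To make the BPTT recursion concrete, here is a sketch for a scalar RNN (\(\ell = m = v = 1\)), checked against finite differences. Several simplifications are assumptions for this example only: the offsets are omitted, \(f_s = \tanh\), \(f_o\) is the identity, and the per-element loss is \(\tfrac{1}{2}(g_t - y_t)^2\) so that \(\partial \mathcal{L}_t / \partial z_t^2 = g_t - y_t\).

```python
import math


def forward(w, xs, s0=0.0):
    """Return states [s_0..s_n] and predictions [g_1..g_n] for weights w."""
    w_sx, w_ss, w_o = w
    states, preds = [s0], []
    for x in xs:
        s = math.tanh(w_sx * x + w_ss * states[-1])
        states.append(s)
        preds.append(w_o * s)
    return states, preds


def seq_loss(w, xs, ys):
    """Sequence loss with the assumed per-element loss 0.5 * (g - y)^2."""
    _, preds = forward(w, xs)
    return sum(0.5 * (g - y) ** 2 for g, y in zip(preds, ys))


def bptt(w, xs, ys):
    """Gradients of the sequence loss w.r.t. (w_sx, w_ss, w_o) via BPTT."""
    w_sx, w_ss, w_o = w
    states, preds = forward(w, xs)
    g_sx = g_ss = g_o = 0.0
    delta = 0.0                                   # delta^{s_n} = 0
    for t in range(len(xs), 0, -1):
        err = preds[t - 1] - ys[t - 1]            # dL_t / dz_t^2
        dF_ds = w_o * err + delta                 # dF_{t-1} / ds_t
        dz1 = (1.0 - states[t] ** 2) * dF_ds      # multiply by tanh'(z_t^1)
        g_sx += dz1 * xs[t - 1]                   # dz_t^1 / dw_sx = x_t
        g_ss += dz1 * states[t - 1]               # dz_t^1 / dw_ss = s_{t-1}
        g_o += err * states[t]                    # (g_t - y_t) s_t
        delta = w_ss * dz1                        # delta^{s_{t-1}}
    return [g_sx, g_ss, g_o]


def numeric_grad(w, xs, ys, eps=1e-6):
    """Central finite-difference gradient, used only as a check."""
    grads = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += eps
        down[i] -= eps
        grads.append((seq_loss(up, xs, ys) - seq_loss(down, xs, ys)) / (2 * eps))
    return grads


w = [0.7, -0.4, 1.3]                              # hypothetical weights
xs, ys = [1.0, -0.5, 0.25], [0.5, 0.0, -0.2]      # hypothetical sequences
analytic, numeric = bptt(w, xs, ys), numeric_grad(w, xs, ys)
print(max(abs(a - b) for a, b in zip(analytic, numeric)))
```

The printed value is the largest disagreement between the two gradient estimates; it should be tiny if the recursion is implemented consistently with the forward pass.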
## Vanishing gradients and gating mechanisms {#sec-rnn_lstm #lstm}
Let's take a careful look at the backward propagation of the gradient along the sequence:
\[
\delta^{s_{t-1}} = \frac{\partial s_t}{\partial s_{t-1}} \left[\frac{\partial \mathcal{L}_\text{elt}(g_t,y_t)}{\partial s_t} + \delta^{s_t}\right].
\]
Consider a case where only the output at the end of the sequence is incorrect, but it depends critically (via the weights) on the input at time 1. In this case, we will multiply the loss at step \(n\) by
\[
\frac{\partial s_2}{\partial s_1}\,\frac{\partial s_3}{\partial s_2}\,\cdots\,\frac{\partial s_n}{\partial s_{n-1}}.
\]
In general, this product will either grow or shrink exponentially with the length of the sequence, making training very difficult.
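A scalar example makes the exponential behavior visible. In the sketch below, the RNN has no input, so \(\partial s_t / \partial s_{t-1} = w^{ss} f_s'(z_t^1)\); the weights 0.5 and 1.5, the starting state, and the linear-activation variant (included only to make the growing case easy to see) are assumptions for this example.

```python
import math


def jacobian_product(w_ss, n, s1=0.1, linear=False):
    """Product of ds_t/ds_{t-1} for t = 2..n in a scalar RNN with no input."""
    s, prod = s1, 1.0
    for _ in range(2, n + 1):
        z = w_ss * s
        if linear:
            s, slope = z, 1.0                 # identity activation (assumption)
        else:
            s = math.tanh(z)
            slope = 1.0 - s ** 2              # tanh'(z)
        prod *= w_ss * slope                  # one factor ds_t / ds_{t-1}
    return prod


for n in [5, 20, 50]:
    shrinking = jacobian_product(0.5, n)               # tanh, |w_ss| < 1
    growing = jacobian_product(1.5, n, linear=True)    # linear, |w_ss| > 1
    print(n, shrinking, growing)
```

With \(w^{ss} = 0.5\) and tanh, each factor is at most 0.5, so the product is bounded by \(0.5^{n-1}\); with \(w^{ss} = 1.5\) and a linear activation, the product is exactly \(1.5^{n-1}\).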
::: {.study_question_callout}
The last time we talked about exploding and vanishing gradients, it was to justify per-weight adaptive step sizes. Why is that not a solution to the problem this time?
:::
An important insight that enabled recurrent networks to work well on long sequences is the idea of *gating*.
### Simple gated recurrent networks
A computer only ever updates some parts of its memory on each computation cycle. We can take this idea and design networks that are better able to retain state values over time and exhibit better-behaved gradients. We introduce a new component called a *gating network*. Let \(g_t\) be an \(m \times 1\) vector (the gate values, not to be confused with the predicted output \(g_t\) used in the BPTT derivation above) and let \(W^{gx}\) and \(W^{gs}\) be weight matrices of dimensions \(m \times \ell\) and \(m \times m\), respectively. We compute \(g_t\) as
::: {.column-margin}
It can have an offset, too, but we are omitting it for simplicity.
:::
\[
g_t = \text{sigmoid}\Bigl(W^{gx}x_t + W^{gs}s_{t-1}\Bigr).
\]
Then we modify the computation of \(s_t\) as
\[
s_t = (1 - g_t) \ast s_{t-1} + g_t \ast f_s\Bigl(W^{sx}x_t + W^{ss}s_{t-1} + W^{ss}_0\Bigr),
\]
where \(\ast\) denotes componentwise multiplication. Here, the gating network’s output determines, for each state dimension, how much it should be updated. This mechanism helps the network learn to “store” information in certain dimensions and then refrain from changing it unnecessarily.
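The gated update can be sketched as follows, with \(f_s = \tanh\) assumed. The weights are hypothetical and deliberately extreme, chosen so that the gate is nearly closed for the first state dimension and nearly open for the second.

```python
import math


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_step(W_gx, W_gs, W_sx, W_ss, W_ss0, s_prev, x):
    """s_t = (1 - g_t) * s_{t-1} + g_t * tanh(W^{sx} x_t + W^{ss} s_{t-1} + W^{ss}_0)."""
    gate = [sigmoid(a + b)
            for a, b in zip(matvec(W_gx, x), matvec(W_gs, s_prev))]
    cand = [math.tanh(a + b + c)
            for a, b, c in zip(matvec(W_sx, x), matvec(W_ss, s_prev), W_ss0)]
    return [(1.0 - g) * sp + g * c for g, sp, c in zip(gate, s_prev, cand)]


# Hypothetical weights with m = 2 and l = 1.
W_gx = [[-20.0], [20.0]]          # gate ~0 for dimension 1, ~1 for dimension 2
W_gs = [[0.0, 0.0], [0.0, 0.0]]
W_sx = [[1.0], [1.0]]
W_ss = [[0.0, 0.0], [0.0, 0.0]]
W_ss0 = [0.0, 0.0]

s_new = gated_step(W_gx, W_gs, W_sx, W_ss, W_ss0, [0.9, -0.9], [1.0])
print(s_new)
```

The first dimension keeps (almost exactly) its old value 0.9, while the second is overwritten by the candidate value \(\tanh(1)\).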
::: {.study_question_callout}
Why is it important that the activation function for \(g\) be a sigmoid?
:::
### Long short-term memory
The idea of gating networks can be extended to design a state machine even more like computer memory, resulting in a network called a *long short-term memory* (LSTM).
::: {.column-margin}
Yet another awesome name for a neural network!
:::
We won’t go into all the details here, but the basic idea is that there is a memory cell (i.e., our state vector) along with three gating networks. The *input* gate selects, via a “soft” selection mechanism, which dimensions of the state to update with new values; the *forget* gate decides which dimensions should have their old values decayed toward zero; and the *output* gate determines which dimensions are used to compute the network’s output. LSTMs have achieved remarkable results in applications such as language translation. A diagram of the LSTM architecture is shown below:
*Figure omitted: diagram of the LSTM architecture.*