Every network we’ve built so far has no memory. Feed it the word “bank” and it gives you an answer without knowing whether the previous words were “river” or “savings”. Each input is processed in isolation.
Real sequences don’t work that way. The meaning of now depends on before. Language, music, time series, video — all of these have temporal structure that a feedforward network is blind to.
Recurrent neural networks solve this with one simple idea: give the network a hidden state that persists between time steps.
The Vanilla RNN #
At each time step $t$, the RNN takes the current input $x_t$ and the previous hidden state $h_{t-1}$, and produces a new hidden state $h_t$:
$$h_t = \tanh(\mathbf{W}_x x_t + \mathbf{W}_h h_{t-1} + \mathbf{b}_h)$$
$$\hat{y}_t = \mathbf{W}_y h_t + \mathbf{b}_y$$
The hidden state $h_t$ is the network’s memory — a compressed summary of everything it has seen up to time $t$. It gets passed back into the network at the next step.
The same weights $\mathbf{W}_x$, $\mathbf{W}_h$, $\mathbf{W}_y$ are used at every time step. This is weight sharing through time — the same transformation is applied, just with a different history in the hidden state.
To process a sequence of length $T$, you “unroll” the network $T$ times — the same weights applied repeatedly, feeding the hidden state forward.
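Here is that recurrence written out as a bare loop, a minimal sketch with illustrative dimensions (the weights are random, not trained):

```python
import torch

# The unrolled recurrence, written out by hand (dimensions are illustrative).
input_size, hidden_size, T = 4, 8, 10
W_x = torch.randn(hidden_size, input_size) * 0.1
W_h = torch.randn(hidden_size, hidden_size) * 0.1
b_h = torch.zeros(hidden_size)

xs = torch.randn(T, input_size)   # a sequence of T inputs
h = torch.zeros(hidden_size)      # initial hidden state

for t in range(T):
    # The same W_x, W_h, b_h at every step -- weight sharing through time.
    h = torch.tanh(W_x @ xs[t] + W_h @ h + b_h)

# h is now a compressed summary of the whole sequence.
```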
Backprop Through Time #
Training an RNN requires computing gradients through the unrolled network — Backpropagation Through Time (BPTT). The loss at step $T$ depends on the hidden state at $T$, which depends on the hidden state at $T-1$, which depends on… all the way back to step 1.
$$\frac{\partial L}{\partial \mathbf{W}_h} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_T} \cdot \left( \prod_{k=t}^{T-1} \frac{\partial h_{k+1}}{\partial h_k} \right) \cdot \frac{\partial h_t}{\partial \mathbf{W}_h}$$
That product of Jacobians is the problem. With tanh activations (whose derivative is at most 1) and many time steps, each factor $\frac{\partial h_{k+1}}{\partial h_k}$ typically has norm less than 1. Multiply 50 of them together and the gradient vanishes. Multiply 50 values greater than 1 and it explodes.
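You can see the effect with nothing but arithmetic:

```python
# 50 per-step factors, each only slightly away from 1 (illustrative numbers).
print(0.9 ** 50)  # ~0.005 -- the gradient all but vanishes
print(1.1 ** 50)  # ~117.4 -- the gradient explodes
```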
Gradient clipping handles the exploding case — cap the gradient norm at some threshold:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
```
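Under the hood, `clip_grad_norm_` computes one total norm over all parameters and rescales the gradients if it exceeds the threshold. Sketched on a single vector with made-up values:

```python
import torch

g = torch.tensor([6.0, 8.0])       # a gradient vector (illustrative)
max_norm = 5.0
norm = g.norm()                    # 10.0
if norm > max_norm:
    g = g * (max_norm / norm)      # rescale: direction preserved, norm capped
print(g)                           # tensor([3., 4.]) -- norm is now 5.0
```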
The vanishing case requires a better architecture.
Truncated BPTT is the practical solution for long sequences: only backpropagate through the last $k$ steps instead of the full sequence. You lose long-range gradients but training becomes stable.
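A common implementation carries the hidden state across chunks but calls `detach()` at each boundary. A sketch, assuming a model like the `CharRNN` below that returns `(logits, hidden)` with an LSTM-style `(h, c)` state tuple, plus pre-existing `sequence`, `targets`, `criterion`, and `optimizer`:

```python
# Truncated BPTT sketch: `sequence` and `targets` are (batch, T, ...) tensors.
k = 50                                      # backprop through at most k steps
hidden = None
for xb, yb in zip(sequence.split(k, dim=1), targets.split(k, dim=1)):
    if hidden is not None:
        # Keep the values but cut the graph: no gradient flows past here.
        hidden = tuple(h.detach() for h in hidden)
    logits, hidden = model(xb, hidden)
    loss = criterion(logits, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```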
The LSTM #
Enter Long Short-Term Memory networks, introduced by Hochreiter and Schmidhuber in 1997. The idea: give the network an explicit cell state $c_t$ — a memory highway that gradients can flow through with minimal decay.
Three gates control what happens to the cell state:
Forget gate — decides what to erase from memory:
$$f_t = \sigma(\mathbf{W}_f [h_{t-1}, x_t] + \mathbf{b}_f)$$
Input gate — decides what new information to write:
$$i_t = \sigma(\mathbf{W}_i [h_{t-1}, x_t] + \mathbf{b}_i)$$ $$\tilde{c}_t = \tanh(\mathbf{W}_c [h_{t-1}, x_t] + \mathbf{b}_c)$$
Cell state update — erase the old, write the new:
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
Output gate — decides what to expose as the hidden state:
$$o_t = \sigma(\mathbf{W}_o [h_{t-1}, x_t] + \mathbf{b}_o)$$ $$h_t = o_t \odot \tanh(c_t)$$
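Put together, one LSTM step is a few matmuls and element-wise products. A hand-rolled sketch with illustrative dimensions and biases omitted (`nn.LSTM` fuses the four matmuls into one for speed, but the math is the same):

```python
import torch

input_size, hidden_size = 4, 8
x_t = torch.randn(input_size)
h_prev, c_prev = torch.zeros(hidden_size), torch.zeros(hidden_size)

# One weight matrix per gate, each acting on the concatenation [h_{t-1}, x_t].
W_f, W_i, W_c, W_o = [torch.randn(hidden_size, hidden_size + input_size) * 0.1
                      for _ in range(4)]
hx = torch.cat([h_prev, x_t])

f_t = torch.sigmoid(W_f @ hx)          # forget gate: what to erase
i_t = torch.sigmoid(W_i @ hx)          # input gate: what to write
c_tilde = torch.tanh(W_c @ hx)         # candidate memory
c_t = f_t * c_prev + i_t * c_tilde     # erase the old, write the new
o_t = torch.sigmoid(W_o @ hx)          # output gate: what to expose
h_t = o_t * torch.tanh(c_t)
```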
The cell state $c_t$ is updated by addition — $c_t = f_t \odot c_{t-1} + \ldots$ When the forget gate is near 1 and the input gate is near 0, the cell state passes through almost unchanged. The gradient of the cell state update is:
$$\frac{\partial c_t}{\partial c_{t-1}} = f_t$$
A gradient flowing backward through the cell state is multiplied by $f_t$ at each step — not a product of Jacobians. If the forget gate is near 1, the gradient passes through cleanly. The LSTM can learn to remember things for hundreds of steps.
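You can check that identity with autograd on a scalar cell (illustrative numbers):

```python
import torch

c_prev = torch.tensor(2.0, requires_grad=True)
f_t = torch.tensor(0.95)                     # forget gate near 1
i_t, c_tilde = torch.tensor(0.0), torch.tensor(0.3)

c_t = f_t * c_prev + i_t * c_tilde
c_t.backward()
print(c_prev.grad)                           # tensor(0.9500) -- exactly f_t
```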
GRU: The Simpler Sibling #
Gated Recurrent Units (Cho et al., 2014) merge the cell state and hidden state into one, reducing the LSTM's three gates to two:
The update gate $z_t$ controls how much of the old state to rewrite; the reset gate $r_t$ controls how much of the past to use when forming the candidate $\tilde{h}_t$:
$$z_t = \sigma(\mathbf{W}_z [h_{t-1}, x_t])$$
$$r_t = \sigma(\mathbf{W}_r [h_{t-1}, x_t])$$
$$\tilde{h}_t = \tanh(\mathbf{W} [r_t \odot h_{t-1}, x_t])$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
Fewer parameters, often similar performance. GRUs are faster to train and work well for many sequence tasks. In practice: if you don’t know which to use, try GRU first.
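In PyTorch the swap is one line: `nn.GRU` takes the same constructor arguments as `nn.LSTM`, but its hidden state is a single tensor rather than an `(h, c)` tuple. Shapes below are illustrative:

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)
x = torch.randn(32, 100, 256)        # (batch, time, features)
out, h_n = gru(x)                    # h_n is a tensor, not an (h, c) tuple
print(out.shape, h_n.shape)          # (32, 100, 256), (2, 32, 256)
```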
In Python #
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embed(x)                   # (batch, time) -> (batch, time, hidden)
        out, hidden = self.lstm(x, hidden)
        return self.fc(out), hidden         # logits over the vocabulary

model = CharRNN(vocab_size=65, hidden_size=256, num_layers=2)
optimizer = torch.optim.Adam(model.parameters())

# Training loop: dataloader yields (chars, targets) batches of token indices
for chars, targets in dataloader:
    optimizer.zero_grad()
    logits, hidden = model(chars)
    loss = F.cross_entropy(logits.view(-1, 65), targets.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # cap exploding gradients
    optimizer.step()
```
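Once trained, the model generates text by sampling one character at a time and feeding it back in. A minimal sampling sketch (temperature and the index-to-character mapping are up to you; `sample` returns token indices):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample(model, start_idx, length=200, temperature=1.0):
    model.eval()
    x = torch.tensor([[start_idx]])            # (batch=1, time=1)
    hidden, out = None, [start_idx]
    for _ in range(length):
        logits, hidden = model(x, hidden)      # hidden state carries the memory
        probs = F.softmax(logits[0, -1] / temperature, dim=-1)
        idx = torch.multinomial(probs, 1).item()
        out.append(idx)
        x = torch.tensor([[idx]])
    return out                                 # map indices back to characters
```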
Demo: Sequence Prophet #
A vanilla RNN trained from scratch on a repeating pattern. Hit Train and watch it learn to predict the next value in the sequence. Then switch patterns and retrain — the sawtooth and square wave are much harder than the sine wave. The animated hidden state bars show what the RNN is “thinking” at each step.
Notice: after training, the RNN tracks the sine wave well — it learns the period and amplitude. The sawtooth is harder — sharp resets confuse it at first. The square wave is hardest — the RNN struggles with sudden transitions because its tanh activations are smooth and can’t represent instant jumps. This is exactly why LSTMs with their gated cell state outperform vanilla RNNs on real sequences.
The hidden state bars show the RNN’s internal memory — different units activate for different phases of the sequence.
What RNNs Are Used For (or Were) #
Before Transformers took over in 2017:
- Language modelling — predict the next character/word (Karpathy’s famous char-rnn)
- Machine translation — encode source sentence into hidden state, decode to target
- Speech recognition — audio frames → text
- Time series — stock prices, sensor data, weather
Today, Transformers have replaced RNNs for most NLP tasks. But RNNs are still used for:
- Real-time sequential tasks where you can’t wait for the full sequence
- Embedded/low-memory devices
- Certain scientific time series problems
Understanding RNNs is essential for understanding why attention and Transformers were invented — they fix the exact problems (vanishing gradients, the one-step-at-a-time sequential bottleneck) that make RNNs hard to train on long sequences.
Before You Go — Try These #
- In the demo, train on Sine and let it run. Now add noise 0.3 and watch the predictions. Does the RNN follow the noise, or does it smooth over it? What does that tell you about what the RNN learned?
- A vanilla RNN with hidden size 8, processing a scalar input, has how many parameters total? Count $\mathbf{W}_x$, $\mathbf{W}_h$, $\mathbf{b}_h$, $\mathbf{W}_y$, $\mathbf{b}_y$.
- Why does the Square wave pattern cause problems for a tanh RNN? What property of tanh makes it hard to represent sudden transitions?
- The LSTM cell state update is $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$. If the forget gate $f_t \approx 1$ and input gate $i_t \approx 0$, what happens to the cell state? What has the LSTM "decided" to do?
- Gradient clipping in BPTT caps the gradient norm. If the gradient vector is $[10, 5, 8, 6]$ and the max norm is 5, what is the scaled gradient vector?
Next up → Lesson 12: Pay Attention — self-attention, positional encoding, and how Transformers replaced everything.