Following up on the earlier post
deciphering a minimal vanilla RNN implementation, here I’d like to extend the
example to a simple LSTM model.
Once again, the idea is to combine a well-commented code sample
with some high-level diagrams and math to enable someone to
fully understand the code. The LSTM architecture presented herein is the
standard one originating from Hochreiter’s and Schmidhuber’s 1997 paper. It’s described pretty much
everywhere; Chris Olah’s post has particularly
nice diagrams and is worth reading.
LSTM cell structure
From 30,000 feet, LSTMs look just like regular RNNs; there’s a “cell” that has a
recurrent connection (output tied to input), and when trained this cell is
usually unrolled to some fixed length.
So we can take the basic RNN structure from the previous post:
LSTMs are a bit trickier because there are two recurrent connections; these
can be “packed” into a single vector h, so the above diagram still applies.
Here’s how an LSTM cell looks inside:
x is the input; p is the vector of probabilities computed from the output y
(these symbols are named consistently with my earlier RNN post); it exits the cell at
the bottom purely for topological convenience. The two memory vectors are h
and c – as mentioned earlier, they could be combined into a single vector, but
are shown here separately for clarity.
The main idea of LSTMs is to enable training of longer sequences by providing
a “fast-path” to back-propagate information farther down in memory. Hence the
c vector is not multiplied by any matrices on its path. The circle-in-circle
block means element-wise multiplication of two vectors; plus-in-square is
element-wise addition. The funny Greek letter is the Sigmoid non-linearity:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
The only other block we haven’t seen in the vanilla RNN diagram is the
colon-in-square in the bottom-left corner; this is simply the concatenation of
h and x into a single column vector. In addition, I’ve combined the
“multiply by matrix W, then add bias b” operation into a single rectangular
box to save on precious diagram space.
Here are the equations computed by a cell:
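In code form, one forward step through such a cell can be sketched roughly as follows. This is a minimal NumPy sketch that assumes the standard LSTM formulation (forget, input and output gates plus a candidate cell vector). The function names `lstm_step` and `init_params` and the parameter names (`Wf`, `bf`, `Wcc` and so on) are placeholders chosen for illustration and may not match the names used in the full code.

```python
import numpy as np


def sigmoid(z):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, params):
    """One forward step of a standard LSTM cell (illustrative sketch).

    x: input column vector [V, 1]; h_prev, c_prev: memory vectors [H, 1].
    params: dict with hypothetical names; gate matrices are [H, H + V],
    gate biases are [H, 1], Wy is [V, H] and by is [V, 1].
    """
    # Concatenate h and x into a single column vector.
    xh = np.vstack((h_prev, x))

    # Each gate is "multiply by matrix W, then add bias b", then a sigmoid.
    f = sigmoid(params["Wf"] @ xh + params["bf"])  # forget gate
    i = sigmoid(params["Wi"] @ xh + params["bi"])  # input gate
    o = sigmoid(params["Wo"] @ xh + params["bo"])  # output gate
    # Candidate values for the cell memory use tanh instead of sigmoid.
    cc = np.tanh(params["Wcc"] @ xh + params["bcc"])

    # c is only touched by element-wise operations, never by a matrix.
    c = f * c_prev + i * cc
    h = o * np.tanh(c)

    # Output y and softmax probabilities p, as in the vanilla RNN.
    y = params["Wy"] @ h + params["by"]
    e = np.exp(y - np.max(y))  # shift for numerical stability
    p = e / np.sum(e)
    return h, c, p


def init_params(V, H, rng):
    """Small random weights and zero biases, for demonstration only."""
    params = {}
    for name in ("f", "i", "o", "cc"):
        params["W" + name] = rng.standard_normal((H, H + V)) * 0.1
        params["b" + name] = np.zeros((H, 1))
    params["Wy"] = rng.standard_normal((V, H)) * 0.1
    params["by"] = np.zeros((V, 1))
    return params
```

Calling `lstm_step` with a one-hot `x` and zero-initialized `h_prev` and `c_prev` returns the new memory vectors along with a probability vector `p` that sums to 1. Note that the update of `c` involves only element-wise operations, which is the "fast path" discussed earlier.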
Backpropagating through an LSTM cell
This works exactly like backprop through a vanilla RNN; we have to carefully
compute how the gradient flows through every node and make sure we properly
combine gradients at fork points. Most of the elements in the LSTM diagram are
familiar from the previous post.
Let’s briefly work through the new ones.
First, the Sigmoid function; it’s an elementwise function, and computing its
derivative is very similar to the tanh function discussed in the previous
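post. As usual, given $f = \sigma(k)$, from the chain rule we have the
following derivative w.r.t. some weight w:

$$\frac{\partial f}{\partial w} = \frac{\partial \sigma(k)}{\partial k} \frac{\partial k}{\partial w}$$

To compute the derivative $\frac{\partial \sigma(k)}{\partial k}$, we’ll
use the ratio-derivative formula:

$$\left(\frac{u}{v}\right)' = \frac{u'v - v'u}{v^2}$$

With $u = 1$ and $v = 1 + e^{-k}$ this gives:

$$\sigma'(k) = \frac{e^{-k}}{(1 + e^{-k})^2}$$

A clever way to express this is:

$$\sigma'(k) = \sigma(k)(1 - \sigma(k))$$

Going back to the chain rule with $f = \sigma(k)$, we get:

$$\frac{\partial f}{\partial w} = f (1 - f) \frac{\partial k}{\partial w}$$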
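As a sanity check, the identity $\sigma'(k) = \sigma(k)(1 - \sigma(k))$ can be compared against a numerical derivative. The snippet below is only an illustrative sketch; the helper names `sigmoid_grad` and `numeric_grad` are made up for this example.

```python
import numpy as np


def sigmoid(k):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-k))


def sigmoid_grad(k):
    """Analytic derivative: sigma(k) * (1 - sigma(k))."""
    s = sigmoid(k)
    return s * (1.0 - s)


def numeric_grad(fn, k, eps=1e-5):
    """Central finite-difference approximation of fn's derivative at k."""
    return (fn(k + eps) - fn(k - eps)) / (2.0 * eps)


# Compare the two on a handful of points.
k = np.linspace(-4.0, 4.0, 9)
max_err = np.max(np.abs(sigmoid_grad(k) - numeric_grad(sigmoid, k)))
print("largest difference:", max_err)
```

The practical benefit of this form is that the backward pass can reuse the value of the sigmoid already computed in the forward pass, without evaluating any more exponentials.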
The other new operation we’ll have to find the derivative of is element-wise
multiplication. Let’s say we have the column vectors x, y and z, each with
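m rows, and we have $z = x \odot y$. Since z as a function of x
has m inputs and m outputs, its Jacobian has dimensions [m,m].
$\frac{\partial z_i}{\partial x_j}$ is the derivative of the i-th element of z w.r.t. the j-th
element of x. For $z = x \odot y$ this is non-zero only
when i and j are equal, and in that case the derivative is $y_i$.
Therefore, $\frac{\partial z}{\partial x}$ is a square matrix with the elements of y on the
diagonal and zeros elsewhere:

$$\frac{\partial z}{\partial x} = \begin{bmatrix} y_1 & 0 & \cdots & 0 \\ 0 & y_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & y_m \end{bmatrix}$$

If we want to backprop some loss L through this function, we get:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \frac{\partial z}{\partial x}$$

As x has m elements, the right-hand side of this equation multiplies a [1,m]
vector by a [m,m] matrix which is diagonal, resulting in element-wise multiplication
with the matrix’s diagonal elements. In other words:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \odot y$$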
In code, it looks like this:
# Assuming dz is the gradient of loss w.r.t. z; dz, y and dx are all
# column vectors.
dx = dz * y
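To see that the element-wise shortcut really matches the full Jacobian product, here is a small check on random data. It is a sketch for illustration only; the variable names `jacobian` and `dx_via_jacobian` are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 4
y = rng.standard_normal((m, 1))
dz = rng.standard_normal((m, 1))  # gradient of the loss w.r.t. z

# Explicit Jacobian of z = x * y w.r.t. x: elements of y on the diagonal.
jacobian = np.diag(y.ravel())

# Row vector [1, m] times the [m, m] Jacobian, transposed back to a column.
dx_via_jacobian = (dz.T @ jacobian).T

# The shortcut: plain element-wise multiplication.
dx = dz * y

print(np.allclose(dx, dx_via_jacobian))
```

The shortcut avoids materializing the [m,m] matrix at all, which matters once m grows to the size of a realistic memory vector.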
In the post about min-char-rnn,
we’ve seen that the vanilla RNN generates fairly low quality text:
one, my dred, roriny. qued bamp gond hilves non froange saws, to mold
his a work, you shirs larcs anverver strepule thunboler
muste, thum and cormed sightourd
so was rewa her besee pilman
The LSTM’s generated text quality is somewhat better when trained with roughly
the same hyper-parameters:
the she, over is was besiving the fact to seramed for i said over he
will round, such when a where, “i went of where stood it at eye heardul rrawed
only coside the showed had off with the refaurtoned
I’m fairly sure that it can be made to perform even better with larger memory
vectors and more training data. That said, an even more advanced architecture
can be helpful here. Moreover, since this is a character-based model, to
really capture effects between words a few words apart we’ll need a much deeper
LSTM (I’m unrolling to 16 characters, so we can only capture 2-3 words), and hence
much more training data and time.
Once again, the goal here is not to develop a state-of-the-art language model,
but to show a simple, comprehensible example of how an LSTM is implemented
end-to-end in Python code. The full code is here
– please let me know if you find any issues with it or something still remains