Long Short-Term Memory (LSTM)
A brief introduction

Daniel Renshaw
24th November 2014
Context and notation

Just to give the LSTM something to do: neural network language modelling.

- Vocabulary of size V
- x_t ∈ R^V: true word in position t (one-hot)
- y_t ∈ R^V: predicted word in position t (a distribution over the vocabulary)
- Assume all sentences are zero-padded to length L
Context and notation

Model:

  y_{t+1} = p(x_{t+1} | x_t, x_{t-1}, ..., x_1)   for 1 ≤ t < L

Minimize the cross-entropy objective:

  J = Σ_{t=1}^{L-1} H(y_{t+1}, x_{t+1}) = -Σ_{t=1}^{L-1} Σ_{i=1}^{V} x_{t+1,i} log(y_{t+1,i})

- σ(·) is some sigmoid-like function (e.g. logistic or tanh)
- b is a bias vector, W is a weight matrix
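A minimal numpy sketch of this objective (function and argument names are hypothetical; the eps guard against log(0) is an addition, not part of the slide):

import numpy as np

def cross_entropy_objective(predictions, targets, eps=1e-12):
    # predictions: (L-1, V) array, row t holds the softmax output y_{t+1}
    # targets:     (L-1, V) array, row t holds the one-hot true word x_{t+1}
    # J = -sum over t and i of x_{t+1,i} * log(y_{t+1,i})
    return -np.sum(targets * np.log(predictions + eps))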
Multi-Layer Perceptron (MLP)

[Diagram: one-hot words x_{t-3}, x_{t-2}, x_{t-1} are embedded as e_{t-3}, e_{t-2}, e_{t-1} and feed the hidden layer h_t, which emits the prediction]

  y_{t+1} = softmax(W_{yh} h_t)
  h_t = σ(W_{he} [e_{t-1}; e_{t-2}; e_{t-3}] + b_h)
  e_t = W_{ex} x_t
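A sketch of this forward pass in numpy, under assumed conformable weight shapes and with tanh as the choice of σ:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())        # shift for numerical stability
    return e / e.sum()

def mlp_predict(x_tm1, x_tm2, x_tm3, W_ex, W_he, W_yh, b_h):
    # Embed the three preceding one-hot words, concatenate: [e_{t-1}; e_{t-2}; e_{t-3}]
    e = np.concatenate([W_ex @ x_tm1, W_ex @ x_tm2, W_ex @ x_tm3])
    h_t = np.tanh(W_he @ e + b_h)  # fixed-size context window, no recurrence
    return softmax(W_yh @ h_t)     # y_{t+1}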
Recurrent Neural Network (RNN)

[Diagram: unrolled RNN; each word x_t is embedded as e_t, the hidden states h_{t-1}, h_t, h_{t+1} are chained, and each h_t emits y_{t+1}]

  y_{t+1} = softmax(W_{yh} h_t)
  h_t = σ(W_{he} e_t + W_{hh} h_{t-1} + b_h)
  e_t = W_{ex} x_t
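A hypothetical numpy sketch of one recurrent step:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def rnn_step(x_t, h_tm1, W_ex, W_he, W_hh, W_yh, b_h):
    e_t = W_ex @ x_t                                 # embed the current word
    h_t = np.tanh(W_he @ e_t + W_hh @ h_tm1 + b_h)   # fold it into the recurrent state
    y_tp1 = softmax(W_yh @ h_t)                      # distribution over the next word
    return h_t, y_tp1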
Vanishing gradients

- Error gradients pass through a nonlinearity at every step
- Unless the weights are large, the error signal will degrade:

  δ_h = σ′(·) W_{(h+1)h} δ_{h+1}

[Image from https://theclevermachine.wordpress.com]
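A toy numpy illustration of this recursion; the sizes, weight scale, and the use of 0.25 (the upper bound of the logistic derivative) are all assumptions made for the demo:

import numpy as np

rng = np.random.default_rng(0)
n = 50
W = rng.normal(scale=0.1, size=(n, n))   # small recurrent weights
delta = rng.normal(size=n)               # error signal at the last step

for step in range(1, 11):
    delta = 0.25 * (W @ delta)           # one application of sigma' * W
    print(step, np.linalg.norm(delta))   # norm shrinks geometrically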
Vanishing gradients

- Gradients may vanish or explode
- Can affect any 'deep' network, e.g. fine-tuning a non-recurrent deep neural network

[Image from Alex Graves' textbook]
Constant Error Carousel

- Allow the network to propagate errors without modification
- No nonlinearity in the recursion

[Diagram: unrolled network; each x_t is embedded as e_t and written into the memory h_t, whose output m_t feeds y_{t+1}; solid arrows denote dense matrix multiplications]

  y_{t+1} = softmax(W_{ym} m_t)
  m_t = σ(h_t)
  h_t = h_{t-1} + σ(W_{he} e_t + W_{hm} m_{t-1} + b_h)
  e_t = W_{ex} x_t
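A hypothetical numpy sketch of one step, to make the additive recursion explicit:

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cec_step(x_t, m_tm1, h_tm1, W_ex, W_he, W_hm, W_ym, b_h):
    e_t = W_ex @ x_t
    # h_{t-1} enters additively, outside any nonlinearity, so error
    # gradients flow back through this recursion unattenuated.
    h_t = h_tm1 + np.tanh(W_he @ e_t + W_hm @ m_tm1 + b_h)
    m_t = np.tanh(h_t)
    y_tp1 = softmax(W_ym @ m_t)
    return m_t, h_t, y_tp1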
LSTM v1: input and output gates

- Attenuate the input and output signals

[Diagram: as before, with logistic gate units i_t (bias b_i) and o_t (bias b_o) scaling the signals entering and leaving the memory h_t]

  y_{t+1} = softmax(W_{ym} m_t)
  m_t = o_t ⊙ σ(h_t)
  o_t = logistic(W_{oe} e_t + W_{om} m_{t-1} + b_o)
  h_t = h_{t-1} + i_t ⊙ σ(W_{he} e_t + W_{hm} m_{t-1} + b_h)
  i_t = logistic(W_{ie} e_t + W_{im} m_{t-1} + b_i)
  e_t = W_{ex} x_t

(⊙ denotes the elementwise product)
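A sketch of this step; the dicts W and b are hypothetical containers for the weight matrices and bias vectors, and the softmax output layer is omitted:

import numpy as np

def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_v1_step(x_t, m_tm1, h_tm1, W, b):
    e_t = W['ex'] @ x_t
    i_t = logistic(W['ie'] @ e_t + W['im'] @ m_tm1 + b['i'])   # input gate
    o_t = logistic(W['oe'] @ e_t + W['om'] @ m_tm1 + b['o'])   # output gate
    h_t = h_tm1 + i_t * np.tanh(W['he'] @ e_t + W['hm'] @ m_tm1 + b['h'])
    m_t = o_t * np.tanh(h_t)   # only the gated view of the memory is exposed
    return m_t, h_t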
LSTM v2: forget (remember) gate

- The model controls when the memory, h_t, is reduced
- The forget gate should really be called a remember gate: f_t close to 1 preserves the memory, close to 0 erases it

  y_{t+1} = softmax(W_{ym} m_t)
  m_t = o_t ⊙ σ(h_t)
  o_t = logistic(W_{oe} e_t + W_{om} m_{t-1} + b_o)
  h_t = f_t ⊙ h_{t-1} + i_t ⊙ σ(W_{he} e_t + W_{hm} m_{t-1} + b_h)
  i_t = logistic(W_{ie} e_t + W_{im} m_{t-1} + b_i)
  f_t = logistic(W_{fe} e_t + W_{fm} m_{t-1} + b_f)
  e_t = W_{ex} x_t
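The only change from v1 is the learned decay of h_{t-1}; a sketch under the same hypothetical naming:

import numpy as np

def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_v2_step(x_t, m_tm1, h_tm1, W, b):
    e_t = W['ex'] @ x_t
    i_t = logistic(W['ie'] @ e_t + W['im'] @ m_tm1 + b['i'])
    f_t = logistic(W['fe'] @ e_t + W['fm'] @ m_tm1 + b['f'])   # new: forget gate
    o_t = logistic(W['oe'] @ e_t + W['om'] @ m_tm1 + b['o'])
    # v1's unconditional "h_tm1 + ..." becomes a learned decay of the memory
    h_t = f_t * h_tm1 + i_t * np.tanh(W['he'] @ e_t + W['hm'] @ m_tm1 + b['h'])
    m_t = o_t * np.tanh(h_t)
    return m_t, h_t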
LSTM v3: peepholes

- Allow the gates to additionally see the internal memory state
- Peephole matrices (W_{ih}, W_{fh}, W_{oh}) are diagonal; all other weight matrices are dense

[Diagram: as v2, with extra connections from the memory into each gate; dashed arrows denote diagonal matrix multiplications]

  y_{t+1} = softmax(W_{ym} m_t)
  m_t = o_t ⊙ σ(h_t)
  o_t = logistic(W_{oe} e_t + W_{om} m_{t-1} + W_{oh} h_t + b_o)
  h_t = f_t ⊙ h_{t-1} + i_t ⊙ σ(W_{he} e_t + W_{hm} m_{t-1} + b_h)
  i_t = logistic(W_{ie} e_t + W_{im} m_{t-1} + W_{ih} h_{t-1} + b_i)
  f_t = logistic(W_{fe} e_t + W_{fm} m_{t-1} + W_{fh} h_{t-1} + b_f)
  e_t = W_{ex} x_t
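A sketch in the same hypothetical style; storing each diagonal peephole matrix as a plain vector (the dict p below) turns its multiplication into an elementwise product:

import numpy as np

def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_v3_step(x_t, m_tm1, h_tm1, W, b, p):
    e_t = W['ex'] @ x_t
    i_t = logistic(W['ie'] @ e_t + W['im'] @ m_tm1 + p['i'] * h_tm1 + b['i'])
    f_t = logistic(W['fe'] @ e_t + W['fm'] @ m_tm1 + p['f'] * h_tm1 + b['f'])
    h_t = f_t * h_tm1 + i_t * np.tanh(W['he'] @ e_t + W['hm'] @ m_tm1 + b['h'])
    o_t = logistic(W['oe'] @ e_t + W['om'] @ m_tm1 + p['o'] * h_t + b['o'])  # peeks at h_t
    m_t = o_t * np.tanh(h_t)
    return m_t, h_t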
LSTM v4: output projection layer

- Reduces the dimensionality of the recursive messages
- Can speed up training without affecting result quality

[Diagram: as v3, with a projection on the path from the gated memory to m_t]

  y_{t+1} = softmax(W_{ym} m_t)
  m_t = W_{mm} (o_t ⊙ σ(h_t))
  o_t = logistic(W_{oe} e_t + W_{om} m_{t-1} + W_{oh} h_t + b_o)
  h_t = f_t ⊙ h_{t-1} + i_t ⊙ σ(W_{he} e_t + W_{hm} m_{t-1} + b_h)
  i_t = logistic(W_{ie} e_t + W_{im} m_{t-1} + W_{ih} h_{t-1} + b_i)
  f_t = logistic(W_{fe} e_t + W_{fm} m_{t-1} + W_{fh} h_{t-1} + b_f)
  e_t = W_{ex} x_t
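Only the output path changes from v3; a minimal sketch (the shape comment is an assumption about the intent of the projection):

import numpy as np

def project_output(o_t, h_t, W_mm):
    # W_mm assumed (d_m, d_h) with d_m < d_h, so the message m_t fed back
    # into the recursion is smaller than the memory h_t itself
    return W_mm @ (o_t * np.tanh(h_t))   # m_t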
Gradients no longer vanish

[Image from Alex Graves' textbook]
LSTM implementations

- RNNLIB (Alex Graves): http://sourceforge.net/p/rnnl/
- PyLearn2 (experimental code, in sandbox/rnn/models/rnn.py)
- Theano, e.g.:

def lstm_step(x_t, m_tm1, h_tm1, w_xe, ..., b_o):
    e_t = dot(x_t, w_xe)                                                   # embed the word
    i_t = sigmoid(dot(e_t, w_ei) + dot(m_tm1, w_mi) + h_tm1 * w_ci + b_i)  # input gate
    f_t = sigmoid(dot(e_t, w_ef) + dot(m_tm1, w_mf) + h_tm1 * w_cf + b_f)  # forget gate
    h_t = f_t * h_tm1 + i_t * tanh(dot(e_t, w_eh) + dot(m_tm1, w_mh) + b_h)  # memory
    o_t = sigmoid(dot(e_t, w_eo) + dot(m_tm1, w_mo) + h_t * w_co + b_o)    # output gate
    m_t = dot(o_t * tanh(h_t), w_mm)                                       # gated, projected output
    y_t = softmax(dot(m_t, w_my))                                          # next-word distribution
    return m_t, h_t, y_t
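To run the step over a whole sentence one would typically wrap it in theano.scan; a hypothetical sketch, where the state size and the params list are assumptions:

import theano
import theano.tensor as T

x = T.matrix('x')                  # (L, V): one sentence as rows of one-hot vectors
m0 = T.zeros((128,))               # initial recurrent message (size assumed)
h0 = T.zeros((128,))               # initial memory state
params = [...]                     # shared weight/bias variables, in lstm_step's order
(m_seq, h_seq, y_seq), updates = theano.scan(
    fn=lstm_step,
    sequences=x,
    outputs_info=[m0, h0, None],   # y_t is emitted but never fed back, hence None
    non_sequences=params)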
Further thoughts

- Sequences vs. hierarchies vs. plain 'deep' networks
- Other solutions to vanishing gradients:
  - Clockwork RNN
  - Different training algorithms (e.g. Hessian-Free optimization)
  - Rectified linear units (ReLU)? σ(x) = max(0, x); constant gradient when active