Expectation maximization tutorial
Octavian Ganea
November 18, 2016

Today

Expectation-maximization algorithm
Topic modelling

ML & MAP

Observed data: $X = \{x_1, x_2, \dots, x_N\}$
Probabilistic model of the data: $p(X \mid \theta) = \prod_{i=1}^{N} p(x_i \mid \theta)$
Estimate parameters:
Maximum likelihood: $\hat{\theta}_{ML} = \arg\max_{\theta} p(X \mid \theta)$
Maximum a posteriori: $\hat{\theta}_{MAP} = \arg\max_{\theta} p(\theta \mid X) = \arg\max_{\theta} \left[ \log p(\theta) + \log p(X \mid \theta) \right]$

Maximizing the log-likelihood

Observed data: $X = \{x_1, x_2, \dots, x_N\}$
Log-likelihood: $l(\theta) = \log p(X \mid \theta) = \sum_{i=1}^{N} \log p(x_i \mid \theta)$
Latent variables: $\log p(X \mid \theta) = \log \sum_{Z} p(X, Z \mid \theta)$
Hard to maximize $l(\theta)$ directly (no closed-form solution in most of the interesting cases).
One solution: use a gradient method (e.g. gradient ascent, Newton).
However, sometimes the gradient is hard to compute or hard to implement, or we do not want a black-box optimization routine with no guarantees.

Expectation-maximization algorithm

Used in models with latent variables.
Iterative algorithm that guarantees convergence to a stationary point of $l(\theta)$ (i.e. a point with zero gradient: either a local optimum or a saddle point).
No guarantee of reaching a global optimum.
Convergence might be slow.
Idea:
Builds a sequence: $l(\theta^{(0)}) \le l(\theta^{(1)}) \le \dots \le l(\theta^{(t)}) \le \dots$
At each step, using Jensen's inequality, finds a lower bound $g$ such that $l(\theta^{(t)}) \le g(\theta^{(t+1)}, q) \le l(\theta^{(t+1)})$

Expectation-maximization algorithm

For any probability distribution $q(Z)$ (s.t. $\sum_{Z} q(Z) = 1$), Jensen's inequality gives a lower bound $F(q, \theta)$ on the true log-likelihood:
$$l(\theta) = \log \left( \sum_{Z} p(X, Z \mid \theta) \right) = \log \left( \sum_{Z} q(Z) \frac{p(X, Z \mid \theta)}{q(Z)} \right) \ge \sum_{Z} q(Z) \log \left( \frac{p(X, Z \mid \theta)}{q(Z)} \right) := F(q, \theta)$$
Reason: $\log(\cdot)$ is concave.
Equality case: $q(Z) = p(Z \mid X, \theta)$.

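The bound can be checked numerically. The sketch below is only an illustration: it assumes a toy model with a single observation (7 heads in 10 tosses from one of two coins chosen with probability 0.5), and the parameter values and the function names `joint`, `log_likelihood` and `lower_bound` are hypothetical choices, not part of the slides.

```python
import math

def joint(x, coin_theta, n=10, prior=0.5):
    """p(x, Z = c; theta): x heads in n tosses (fixed order) from coin c."""
    return prior * coin_theta**x * (1 - coin_theta)**(n - x)

def log_likelihood(x, theta_a, theta_b):
    """l(theta) = log sum_Z p(x, Z; theta)."""
    return math.log(joint(x, theta_a) + joint(x, theta_b))

def lower_bound(x, theta_a, theta_b, q_a):
    """F(q, theta) = sum_Z q(Z) log(p(x, Z; theta) / q(Z)), with q(A) = q_a."""
    total = 0.0
    for q, theta in ((q_a, theta_a), (1 - q_a, theta_b)):
        if q > 0:  # a term with q(Z) = 0 contributes nothing
            total += q * math.log(joint(x, theta) / q)
    return total

# Hypothetical observation and parameters, chosen only for illustration.
x, theta_a, theta_b = 7, 0.6, 0.5
l = log_likelihood(x, theta_a, theta_b)
posterior_a = joint(x, theta_a) / (joint(x, theta_a) + joint(x, theta_b))

# F stays below l for arbitrary q and touches it when q is the posterior.
for q_a in (0.1, 0.5, 0.9, posterior_a):
    f = lower_bound(x, theta_a, theta_b, q_a)
    print(f"q(A)={q_a:.3f}  F={f:.6f}  l={l:.6f}")
```

The last printed row uses the posterior as $q$, which is the equality case of Jensen's inequality.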
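Expectation-maximization algorithm

Update rule:
$$\theta^{(t+1)} = \arg\max_{\theta} g_t(\theta)$$
where
$$g_t(\theta) := F(p(Z \mid X, \theta^{(t)}), \theta) = \sum_{Z} p(Z \mid X, \theta^{(t)}) \log \left( \frac{p(X, Z \mid \theta)}{p(Z \mid X, \theta^{(t)})} \right)$$
From above, $g_t(\theta) \le l(\theta), \forall \theta$; in particular: $g_t(\theta^{(t+1)}) \le l(\theta^{(t+1)})$
Equality in Jensen: $g_t(\theta^{(t)}) = l(\theta^{(t)})$
So: $l(\theta^{(t)}) = g_t(\theta^{(t)}) \le g_t(\theta^{(t+1)}) \le l(\theta^{(t+1)})$
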
Expectation-maximization algorithm

EM algorithm:
E-step: $q^{(t+1)} = \arg\max_{q} F(q, \theta^{(t)})$ (i.e. $q^{(t+1)} = p(Z \mid X, \theta^{(t)})$)
M-step: $\theta^{(t+1)} = \arg\max_{\theta} F(q^{(t+1)}, \theta)$

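EM algorithm - convergence

We proved so far that: $l(\theta^{(0)}) \le l(\theta^{(1)}) \le \dots \le l(\theta^{(t)}) \le \dots$
But why does it converge to a stationary point? (Who guarantees no early stopping?)
Proof: Let $\theta^*$ be the limit of the sequence defined by the EM algorithm.
Then: $\theta^* = \arg\max_{\theta} g^*(\theta)$, where $g^*(\theta) = F(p(Z \mid X, \theta^*), \theta)$.
This implies: $\nabla_{\theta} g^*(\theta^*) = 0$.
Let
$$h^*(\theta) := l(\theta) - g^*(\theta) = -\sum_{Z} p(Z \mid X, \theta^*) \log \left( \frac{p(Z \mid X, \theta)}{p(Z \mid X, \theta^*)} \right)$$
Then, $h^*(\theta) \ge 0, \forall \theta$ (since $g^*$ is a lower bound of $l$) and $h^*(\theta^*) = 0$ (Jensen equality case).
So, $\theta^* = \arg\min_{\theta} h^*(\theta) \Rightarrow \nabla_{\theta} h^*(\theta^*) = 0$
So, $\nabla_{\theta} l(\theta^*) = \nabla_{\theta} h^*(\theta^*) + \nabla_{\theta} g^*(\theta^*) = 0$, q.e.d.
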
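The closed form of $h^*$ follows in two lines from the definitions of $l$ and $g^*$. The derivation below is a sketch of that step; it uses only $\sum_Z p(Z \mid X, \theta^*) = 1$ and the factorization $p(X, Z \mid \theta) = p(Z \mid X, \theta)\, p(X \mid \theta)$, and it identifies the result as a Kullback-Leibler divergence, a name the slides do not use.

```latex
\begin{align*}
h^*(\theta)
  &= \log p(X \mid \theta)
     - \sum_{Z} p(Z \mid X, \theta^*)
       \log \frac{p(X, Z \mid \theta)}{p(Z \mid X, \theta^*)} \\
  &= \sum_{Z} p(Z \mid X, \theta^*)
     \left[ \log p(X \mid \theta) - \log p(X, Z \mid \theta)
            + \log p(Z \mid X, \theta^*) \right] \\
  &= - \sum_{Z} p(Z \mid X, \theta^*)
       \log \frac{p(Z \mid X, \theta)}{p(Z \mid X, \theta^*)}
   = \mathrm{KL}\left( p(Z \mid X, \theta^*) \,\|\, p(Z \mid X, \theta) \right)
   \ge 0
\end{align*}
```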
EM Applications

Tired of too much math? :)
Let's look at some cool applications of EM.

Application 1: Coin Flipping

There are two coins A and B, with $\theta_A$ and $\theta_B$ being their probabilities of landing heads when tossed.
Do 5 rounds. In each round, select one coin uniformly at random, toss it 10 times, and record the results.
The observed data consists of 50 coin tosses.
However, we don't know which coin was selected in any particular round.
Estimate $\theta_A$ and $\theta_B$.

Application 1: Coin Flipping

Let's start simple:
One coin A with $P(Y = H) = \theta_A$
10 tosses: $\#H = x \in \{0, \dots, 10\}$, $\#T = 10 - x$
How to estimate $\theta_A$? Maximize what we see!
Mathematically, maximize the data (log-)likelihood: $\hat{\theta}_A = \arg\max_{\theta_A} l(\theta_A)$, where $l(\theta_A) := \log P(X = x \mid \theta_A)$
$P(X = x \mid \theta_A) = \theta_A^{x} (1 - \theta_A)^{10 - x}$ (note: fixed order of tosses)
$l(\theta_A) = x \log(\theta_A) + (10 - x) \log(1 - \theta_A)$
Set derivative to 0: $\frac{\partial l}{\partial \theta_A}(\hat{\theta}_A) = 0 \Rightarrow \hat{\theta}_A = \frac{x}{10}$
Best ML distribution is the empirical distribution.

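As a sanity check on the single-coin case, the closed-form estimate $x/10$ can be compared with a brute-force maximization of $l(\theta_A)$ over a grid. This is a minimal sketch; the function names and the observed count of 7 heads are hypothetical.

```python
import math

def log_likelihood(theta, heads, tosses=10):
    """l(theta) = x log(theta) + (n - x) log(1 - theta), for 0 < theta < 1."""
    return heads * math.log(theta) + (tosses - heads) * math.log(1 - theta)

def mle_closed_form(heads, tosses=10):
    """Stationary point of l: the empirical frequency of heads."""
    return heads / tosses

def mle_grid_search(heads, tosses=10, steps=999):
    """Maximize l over the grid 1/1000, 2/1000, ..., 999/1000."""
    grid = [i / (steps + 1) for i in range(1, steps + 1)]
    return max(grid, key=lambda theta: log_likelihood(theta, heads, tosses))

heads = 7  # hypothetical observation: 7 heads out of 10 tosses
print(mle_closed_form(heads), mle_grid_search(heads))
```

The grid excludes the endpoints 0 and 1, where the logarithm is undefined.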
Application 1: Coin Flipping

Back to our original problem.
Parameters: $\theta = \{\theta_A, \theta_B\}$
Latent r.v. $Z_r$: the coin selected in round $r \in \{1, \dots, 5\}$, with $p(Z_r = A) = p(Z_r = B) = 0.5$
In each round $r$, the number of heads is $x_r$. Associated r.v.: $X_r$.
$p(X_r = x_r \mid Z_r = A; \theta) = \theta_A^{x_r} (1 - \theta_A)^{10 - x_r}$
Bayes rule:
$$p(Z_r = A \mid x_r; \theta) = \frac{\theta_A^{x_r} (1 - \theta_A)^{10 - x_r}}{\theta_A^{x_r} (1 - \theta_A)^{10 - x_r} + \theta_B^{x_r} (1 - \theta_B)^{10 - x_r}}$$

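The Bayes-rule posterior is the quantity the E-step later needs for every round. A minimal sketch of it follows, with hypothetical function names; the uniform prior 0.5 cancels between numerator and denominator, as in the formula above.

```python
def round_likelihood(heads, theta, tosses=10):
    """p(x_r | Z_r = c; theta) for a fixed order of tosses."""
    return theta**heads * (1 - theta)**(tosses - heads)

def posterior_coin_a(heads, theta_a, theta_b, tosses=10):
    """p(Z_r = A | x_r; theta) under a uniform prior over the two coins."""
    like_a = round_likelihood(heads, theta_a, tosses)
    like_b = round_likelihood(heads, theta_b, tosses)
    return like_a / (like_a + like_b)

# Hypothetical parameters: a round with many heads points towards the coin
# with the larger head probability.
print(posterior_coin_a(9, theta_a=0.8, theta_b=0.3))
print(posterior_coin_a(2, theta_a=0.8, theta_b=0.3))
```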
Application 1: Coin Flipping

Data likelihood (per one round):
$$p(x_r; \theta) = p(x_r \mid Z_r = A; \theta)\, p(Z_r = A) + p(x_r \mid Z_r = B; \theta)\, p(Z_r = B) = 0.5 \left( \theta_A^{x_r} (1 - \theta_A)^{10 - x_r} + \theta_B^{x_r} (1 - \theta_B)^{10 - x_r} \right)$$
Data log-likelihood (all rounds):
$$l(\theta) = \log p(X; \theta) = \sum_{r=1}^{5} \log p(x_r; \theta)$$
Cannot maximize the log-likelihood directly (i.e. by setting the gradient to zero).
Instead, maximize the EM lower bound on $l(\theta)$ (formalized last time).

Application 1: Coin Flipping

EM lower bound per round (Jensen's inequality):
$$\log p(x_r; \theta) \ge \sum_{c \in \{A, B\}} q_r(Z_r = c) \log \left( \frac{p(x_r, Z_r = c; \theta)}{q_r(Z_r = c)} \right) := F_r(q_r, \theta)$$
Expectation step: $q_r(Z_r = c) = p(Z_r = c \mid x_r; \theta^{(t)}), \forall r \in \{1, \dots, 5\}$
Maximization step:
$$\theta^{(t+1)} = \arg\max_{\theta} g_t(\theta), \quad \text{where } g_t(\theta) = \sum_{r=1}^{5} F_r(p(Z_r = \cdot \mid x_r, \theta^{(t)}), \theta)$$

Application 1: Coin Flipping

Maximization step (the denominator inside the logarithm of $F_r$ does not depend on $\theta$, so it can be dropped):
$$\theta^{(t+1)} = \arg\max_{\theta} \sum_{r=1}^{5} \sum_{c \in \{A, B\}} p(Z_r = c \mid x_r, \theta^{(t)}) \log \left( p(x_r, Z_r = c; \theta) \right)$$
Gradient:
$$\frac{\partial g_t(\theta)}{\partial \theta_A} = \sum_{r=1}^{5} p(Z_r = A \mid x_r, \theta^{(t)}) \frac{\partial \log \left( p(x_r, Z_r = A; \theta) \right)}{\partial \theta_A} = \sum_{r=1}^{5} p(Z_r = A \mid x_r, \theta^{(t)}) \left( \frac{x_r}{\theta_A} - \frac{10 - x_r}{1 - \theta_A} \right)$$
Gradient set to 0 gives:
$$\theta_A^{(t+1)} = \frac{\alpha_A^{(t)}}{\alpha_A^{(t)} + \beta_A^{(t)}}$$
where
$$\alpha_A^{(t)} = \sum_{r=1}^{5} p(Z_r = A \mid x_r, \theta^{(t)})\, x_r; \quad \beta_A^{(t)} = \sum_{r=1}^{5} p(Z_r = A \mid x_r, \theta^{(t)}) (10 - x_r)$$

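For completeness, the algebra between the zero-gradient condition and the update can be written out as follows. This is a sketch that reuses the weighted head and tail counts $\alpha_A^{(t)}$ and $\beta_A^{(t)}$ of this slide and introduces no new quantities.

```latex
\begin{align*}
0 &= \sum_{r=1}^{5} p(Z_r = A \mid x_r, \theta^{(t)})
     \left( \frac{x_r}{\theta_A} - \frac{10 - x_r}{1 - \theta_A} \right)
   = \frac{\alpha_A^{(t)}}{\theta_A} - \frac{\beta_A^{(t)}}{1 - \theta_A} \\
&\Rightarrow \alpha_A^{(t)} (1 - \theta_A) = \beta_A^{(t)} \theta_A
 \;\Rightarrow\; \theta_A = \frac{\alpha_A^{(t)}}{\alpha_A^{(t)} + \beta_A^{(t)}}
\end{align*}
```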
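Application 1: Coin Flipping

Final algorithm:
Iteration: $t \leftarrow 0$
Initialize parameters randomly: $\theta_A^{(0)}, \theta_B^{(0)} \in (0, 1)$
Do until convergence:
$\theta_A^{(t+1)} = \frac{\alpha_A^{(t)}}{\alpha_A^{(t)} + \beta_A^{(t)}}$
$\theta_B^{(t+1)} = \frac{\alpha_B^{(t)}}{\alpha_B^{(t)} + \beta_B^{(t)}}$
$t \leftarrow t + 1$
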
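The whole procedure fits in a few lines of Python. The sketch below makes several assumptions that are not in the slides: the head counts per round are made up, the starting point is fixed (rather than random) so that the run is reproducible, and convergence is declared when the parameters move by less than a hypothetical tolerance `tol`.

```python
import math

def posterior_a(heads, theta_a, theta_b, tosses):
    """E-step quantity p(Z_r = A | x_r; theta) under a uniform coin prior."""
    like_a = theta_a**heads * (1 - theta_a)**(tosses - heads)
    like_b = theta_b**heads * (1 - theta_b)**(tosses - heads)
    return like_a / (like_a + like_b)

def log_likelihood(head_counts, theta_a, theta_b, tosses=10):
    """l(theta) = sum_r log p(x_r; theta)."""
    total = 0.0
    for x in head_counts:
        like_a = theta_a**x * (1 - theta_a)**(tosses - x)
        like_b = theta_b**x * (1 - theta_b)**(tosses - x)
        total += math.log(0.5 * (like_a + like_b))
    return total

def em_step(head_counts, theta_a, theta_b, tosses=10):
    """One E-step plus M-step: theta_c = alpha_c / (alpha_c + beta_c)."""
    alpha_a = beta_a = alpha_b = beta_b = 0.0
    for x in head_counts:
        p_a = posterior_a(x, theta_a, theta_b, tosses)
        alpha_a += p_a * x                    # weighted heads for coin A
        beta_a += p_a * (tosses - x)          # weighted tails for coin A
        alpha_b += (1 - p_a) * x              # weighted heads for coin B
        beta_b += (1 - p_a) * (tosses - x)    # weighted tails for coin B
    return alpha_a / (alpha_a + beta_a), alpha_b / (alpha_b + beta_b)

def run_em(head_counts, theta_a, theta_b, tosses=10, tol=1e-8, max_iter=1000):
    """Iterate EM steps; also record l(theta) after every step."""
    history = [log_likelihood(head_counts, theta_a, theta_b, tosses)]
    for _ in range(max_iter):
        new_a, new_b = em_step(head_counts, theta_a, theta_b, tosses)
        history.append(log_likelihood(head_counts, new_a, new_b, tosses))
        moved = abs(new_a - theta_a) + abs(new_b - theta_b)
        theta_a, theta_b = new_a, new_b
        if moved < tol:
            break
    return theta_a, theta_b, history

head_counts = [5, 9, 8, 4, 7]  # hypothetical heads per round, 10 tosses each
theta_a, theta_b, history = run_em(head_counts, theta_a=0.6, theta_b=0.5)
print(theta_a, theta_b, len(history) - 1)
```

The recorded `history` makes the monotonicity property $l(\theta^{(t)}) \le l(\theta^{(t+1)})$ easy to inspect. Starting from $\theta_A^{(0)} = \theta_B^{(0)}$ would leave both estimates equal forever, which is one reason to initialize them differently.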
Application 2: Topic Modelling

Document representations:
Used for classification, query retrieval, document similarity, etc.
A document can be seen as a multi-set of words: $d = \{(w_i, tf(w_i; d))\}_{i = 1, \dots, V} \in \mathbb{R}^{V}$
Issues: high dimensionality, sparsity, potentially many infrequent words (with noisy estimated parameters)
Alternative (compressed topic representation): topic distributions: $d = \{(t, p(t \mid d))\}_{t = 1, \dots, K} \in \mathbb{R}^{K}$
$K$ = number of topics, $K \ll V$
How to choose the number of topics $K$?
Hyper-parameter: pick the value that gives the best performance on a validation set for the task at hand
Minimize perplexity of seen words

Application 2: Topic Modelling

Model parameters (to be learned): $\pi_t := p(t \mid d)$, $a_{nt} := p(w_n \mid t)$
Log-likelihood (one document), where $T$ is the number of topics (written $K$ on the previous slide):
$$l(\pi) = \sum_{n=1}^{N} \log p(w_n \mid d) = \sum_{n=1}^{N} \log \sum_{t=1}^{T} \pi_t a_{nt}$$
Iterative algorithm: keep $a_{nt}$ fixed, learn $\pi_t$; and reverse.
We do here just the update of $\pi_t$. The update of $a_{nt}$ is similar.
Log-likelihood with Lagrange multipliers:
$$L(\pi, \lambda) = \sum_{n=1}^{N} \log \sum_{t=1}^{T} \pi_t a_{nt} - \lambda \left( \sum_{t=1}^{T} \pi_t - 1 \right)$$

Application 2: Topic Modelling

Iterative update algorithm.
Latent variables $Z$ are now the topics $t$.
EM lower bound using Jensen:
$$L(\pi, \lambda) \ge F(q, \pi, \lambda) = \sum_{n=1}^{N} \sum_{t=1}^{T} q_{nt} \left[ \log \frac{\pi_t}{q_{nt}} + \log a_{nt} \right] - \lambda \left( \sum_{t=1}^{T} \pi_t - 1 \right)$$
where $\sum_{t} q_{nt} = 1, \forall n$
E-step, iteration $k$:
$$q_{nt}^{(k)} = \frac{\pi_t^{(k)} a_{nt}}{\sum_{t'} \pi_{t'}^{(k)} a_{nt'}}$$
M-step, iteration $k$:
$$\pi_t^{(k+1)} = \frac{\pi_t^{(k)}}{N} \sum_{n=1}^{N} \frac{a_{nt}}{\sum_{t'} \pi_{t'}^{(k)} a_{nt'}}$$

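The two updates translate directly into code. The sketch below assumes the word-given-topic probabilities $a_{nt}$ are fixed and known; the matrix `a`, the uniform starting point and the fixed iteration count are hypothetical choices for illustration. The M-step is written as $\pi_t = \frac{1}{N} \sum_n q_{nt}$, which is the same expression as the update above after substituting the E-step.

```python
def em_mixing_weights(a, pi, iterations=200):
    """EM for the topic proportions pi[t] = p(t | d) of one document.

    a[n][t] = p(w_n | t) for the n-th word of the document (kept fixed).
    """
    n_words = len(a)
    n_topics = len(pi)
    for _ in range(iterations):
        # E-step: q[n][t] = pi[t] * a[n][t] / sum_t' pi[t'] * a[n][t']
        q = []
        for row in a:
            weighted = [p * a_nt for p, a_nt in zip(pi, row)]
            norm = sum(weighted)
            q.append([w / norm for w in weighted])
        # M-step: pi[t] = (1 / N) * sum_n q[n][t]
        pi = [sum(q_n[t] for q_n in q) / n_words for t in range(n_topics)]
    return pi

# Hypothetical document of 4 words and 2 topics; each row holds p(w_n | t).
a = [
    [0.30, 0.01],
    [0.25, 0.02],
    [0.02, 0.20],
    [0.20, 0.05],
]
pi = em_mixing_weights(a, pi=[0.5, 0.5])
print(pi)
```

Because every row of `q` sums to 1, the updated `pi` sums to 1 automatically, which is what the Lagrange multiplier enforces in the derivation.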
Questions?