Variational inference in the conjugate-exponential family Matthew J. Beal Work with Zoubin Ghahramani Gatsby Computational Neuroscience Unit August 2000
Abstract
When learning models from data there is always a risk of over-fitting, and with it poor generalisation performance. In a Bayesian approach this can be resolved by averaging over all possible settings of the model parameters. However, for most interesting problems this averaging leads to intractable integrals, so approximations are necessary. Variational methods are becoming a widespread tool for approximate inference and learning in many graphical models: they are deterministic, generally fast, and can monotonically increase an objective function that transparently incorporates a model complexity cost. I'll provide some theoretical results for the variational updates in a very general family of conjugate-exponential graphical models, and show how well-known algorithms (e.g. belief propagation) can be readily incorporated into the variational updates. Some examples illustrate the ideas: determining the most probable number of clusters and their intrinsic latent-space dimensionality in a Mixture of Factor Analysers model, and recovering the hidden state-space dimensionality in a Linear Dynamical System model. This is work with Zoubin Ghahramani.
Outline
• Briefly review variational Bayesian learning.
• Concentrate on conjugate-exponential models.
• Variational Expectation-Maximisation (VEM).
• Examples in Mixture of Factor Analysers (MFA) and Linear Dynamical Systems (LDS).
Bayesian approach
Motivation
• Maximum likelihood (ML) does not penalise complex models, which can a priori model a larger range of data sets.
• Bayes does not suffer from overfitting if we integrate out all the parameters.
• We should be able to do principled model comparison.
Method
• Express distributions over all parameters in the model.
• Calculate the evidence $P(Y \mid m) = \int d\theta \, P(Y \mid \theta, m) \, P(\theta \mid m)$.
[Figure: P(Data) plotted over data sets Y for a "too simple", a "just right" and a "too complex" model (adapted from D.J.C. MacKay).]
Problem
• These integrals are computationally intractable.
Practical approaches
• Laplace approximations: appeal to the Central Limit Theorem, making a Gaussian approximation about the maximum a posteriori parameter estimate.
• Large sample approximations: e.g. BIC.
• Markov chain Monte Carlo (MCMC): guaranteed to converge in the limit. Many samples required for accurate results. Hard to assess convergence.
• Variational approximations.
The variational method
Let the hidden states be $X$ and the parameters $\theta$.
$$\ln P(Y) = \ln \int dX \, d\theta \; P(Y, X, \theta) \;\ge\; \int dX \, d\theta \; Q(X, \theta) \ln \frac{P(Y, X, \theta)}{Q(X, \theta)}.$$
Approximate the intractable distribution over hidden states ($X$) and parameters ($\theta$) with a simpler distribution:
$$\ln P(Y) = \ln \int dX \, d\theta \; P(Y, X, \theta) \;\ge\; \int dX \, d\theta \; Q_X(X) Q_\theta(\theta) \ln \frac{P(Y, X, \theta)}{Q_X(X) Q_\theta(\theta)} = \mathcal{F}(Q, Y).$$
Maximisation of this lower bound leads to EM-like updates:
E step: $Q_X^*(X) \propto \exp \langle \ln P(X, Y \mid \theta) \rangle_{Q_\theta(\theta)}$
M step: $Q_\theta^*(\theta) \propto P(\theta) \exp \langle \ln P(X, Y \mid \theta) \rangle_{Q_X(X)}$
Equivalent to minimising the KL-divergence between the approximating and true posteriors.
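The equivalence with KL minimisation follows in a few lines from the definition of the bound. The short derivation below is a sketch that uses only the quantities already defined on this slide, writing $\mathcal{F}$ for the lower bound F:

```latex
\begin{align*}
\ln P(Y) - \mathcal{F}(Q, Y)
  &= \int dX \, d\theta \; Q_X(X) Q_\theta(\theta)
     \left[ \ln P(Y) - \ln \frac{P(Y, X, \theta)}{Q_X(X) Q_\theta(\theta)} \right] \\
  &= \int dX \, d\theta \; Q_X(X) Q_\theta(\theta)
     \ln \frac{Q_X(X) Q_\theta(\theta)}{P(X, \theta \mid Y)} \\
  &= \mathrm{KL}\!\left( Q_X Q_\theta \,\|\, P(X, \theta \mid Y) \right) \;\ge\; 0 .
\end{align*}
```

Since $\ln P(Y)$ does not depend on $Q$, every increase in $\mathcal{F}$ is an equal decrease in this KL divergence.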
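A bit more explanation
$$\mathcal{F}(Q, Y) = \int dX \, d\theta \; Q_X(X) Q_\theta(\theta) \ln \frac{P(X, Y, \theta)}{Q_X(X) Q_\theta(\theta)} = \int d\theta \; Q_\theta(\theta) \left[ \ln \frac{P(\theta)}{Q_\theta(\theta)} + \int dX \; Q_X(X) \ln \frac{P(X, Y \mid \theta)}{Q_X(X)} \right].$$
E.g. taking a functional derivative:
$$\frac{\partial \mathcal{F}(Q, Y)}{\partial Q_\theta(\theta)} = \ln P(\theta) - 1 - \ln Q_\theta(\theta) + \langle \ln P(X, Y \mid \theta) \rangle_{Q_X(X)} + \text{const},$$
where the constant (the entropy of $Q_X$) does not depend on $\theta$. Setting this to zero, with a Lagrange multiplier enforcing normalisation of $Q_\theta$, gives the M step; the E step follows in the same way.
E step: $Q_X^*(X) \propto \exp \langle \ln P(X, Y \mid \theta) \rangle_{Q_\theta(\theta)}$
M step: $Q_\theta^*(\theta) \propto P(\theta) \exp \langle \ln P(X, Y \mid \theta) \rangle_{Q_X(X)}$
Question: What distributions can we make use of in our models?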
Conjugate-Exponential models
Consider variational Bayesian learning in models that satisfy:
Condition (1). The complete data likelihood is in the exponential family:
$$P(x, y \mid \theta) = f(x, y) \, g(\theta) \exp\left\{ \phi(\theta)^\top u(x, y) \right\}$$
where $\phi(\theta)$ is the vector of natural parameters, and $u$, $f$ and $g$ are the functions that define the exponential family.
Condition (2). The parameter prior is conjugate to the complete data likelihood:
$$P(\theta \mid \eta, \nu) = h(\eta, \nu) \, g(\theta)^{\eta} \exp\left\{ \phi(\theta)^\top \nu \right\}$$
where $\eta$ and $\nu$ are hyperparameters of the prior.
We call models that satisfy conditions (1) and (2) conjugate-exponential.
Exponential family models
Some models that are in the exponential family:
• Gaussian mixtures,
• factor analysis, probabilistic PCA,
• hidden Markov models and factorial HMMs,
• linear dynamical systems and switching state-space models,
• Boltzmann machines,
• discrete-variable belief networks.
Other as yet undreamt-of models can combine Gaussian, Gamma, Poisson, Dirichlet, Wishart, Multinomial and others.
Some which are not in the exponential family:
• sigmoid belief networks,
• independent components analysis.
Make approximations/changes to move into the exponential family...
Theoretical Results
Theorem 1. Given an iid data set $Y = (y_1, \ldots, y_n)$, if the model satisfies conditions (1) and (2), then at the maxima of $\mathcal{F}(Q, Y)$ (minima of $\mathrm{KL}(Q \| P)$):
(a) $Q_\theta(\theta)$ is conjugate and of the form
$$Q_\theta(\theta) = h(\tilde{\eta}, \tilde{\nu}) \, g(\theta)^{\tilde{\eta}} \exp\left\{ \phi(\theta)^\top \tilde{\nu} \right\}$$
where $\tilde{\eta} = \eta + n$, $\tilde{\nu} = \nu + \sum_{i=1}^{n} \bar{u}(x_i, y_i)$, and $\bar{u}(x_i, y_i) = \langle u(x_i, y_i) \rangle_Q$, using $\langle \cdot \rangle_Q$ to denote expectation under $Q$.
(b) $Q_X(X) = \prod_{i=1}^{n} Q_{x_i}(x_i)$ and $Q_{x_i}(x_i)$ is of the same form as the known parameter posterior:
$$Q_{x_i}(x_i) \propto f(x_i, y_i) \exp\left\{ \bar{\phi}(\theta)^\top u(x_i, y_i) \right\} = P(x_i \mid y_i, \bar{\phi}(\theta))$$
where $\bar{\phi}(\theta) = \langle \phi(\theta) \rangle_Q$.
KEY points:
(a) the approximate parameter posterior is of the same form as the prior;
(b) the approximate hidden variable posterior is of the same form as the hidden variable posterior for known parameters.
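As a concrete instance of part (a), the sketch below applies the hyperparameter update to a Bernoulli likelihood with no hidden variables, so that the expected sufficient statistics are just the observed ones. The choice of model, the function names and the example data are illustrative assumptions, not taken from the slides.

```python
def bernoulli_natural_update(eta, nu, ys):
    """Hyperparameter update of Theorem 1(a) for a Bernoulli likelihood.

    Assumed exponential-family form (illustrative):
      P(y | theta) = g(theta) * exp(phi(theta) * u(y))
      g(theta) = 1 - theta, phi(theta) = log(theta / (1 - theta)), u(y) = y
    Conjugate prior: P(theta | eta, nu) proportional to
      g(theta)**eta * exp(phi(theta) * nu)
    """
    eta_post = eta + len(ys)   # eta~ = eta + n
    nu_post = nu + sum(ys)     # nu~ = nu + sum_i u(y_i)
    return eta_post, nu_post


def to_beta(eta, nu):
    """Rewrite g**eta * exp(phi * nu) = theta**nu * (1 - theta)**(eta - nu)
    as the parameters of a Beta(nu + 1, eta - nu + 1) density."""
    return nu + 1.0, eta - nu + 1.0


# Hypothetical data: 4 ones and 2 zeros, with a Beta(2, 2) prior (eta=2, nu=1).
ys = [1, 0, 1, 1, 0, 1]
eta_post, nu_post = bernoulli_natural_update(2.0, 1.0, ys)
print(to_beta(eta_post, nu_post))
```

The result agrees with the textbook Beta update (add the number of ones to the first parameter and the number of zeros to the second), which is what conjugacy in the form of condition (2) predicts.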
Variational EM algorithm
Since $Q_\theta(\theta)$ and $Q_{x_i}(x_i)$ are coupled, (a) and (b) do not provide an analytic solution to the minimisation problem.
VE Step: Compute the expected sufficient statistics $t(Y) = \sum_i \bar{u}(x_i, y_i)$ under the hidden variable distributions $Q_{x_i}(x_i)$.
VM Step: Compute the expected natural parameters $\bar{\phi}(\theta)$ under the parameter distribution given by $\tilde{\eta}$ and $\tilde{\nu}$.
Properties:
• VE step has the same complexity as the corresponding E step.
• Reduces to the EM algorithm if $Q_\theta(\theta) = \delta(\theta - \theta^*)$. The M step then involves re-estimation of $\theta^*$.
• $\mathcal{F}$ increases monotonically, and incorporates a model complexity penalty.
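The VE/VM alternation can be sketched on a deliberately small conjugate-exponential model: a mixture whose Gaussian components are known and whose only unknown parameter is the vector of mixing proportions, given a symmetric Dirichlet prior. This toy model, the function names and the hand-written `digamma` helper are assumptions made for illustration; they are not the models treated in the slides.

```python
import math


def digamma(x):
    """psi(x) for x > 0, via the recurrence psi(x) = psi(x + 1) - 1/x
    followed by the standard asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return (result + math.log(x) - 0.5 / x
            - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252)))


def gauss_pdf(y, mean):
    """Unit-variance Gaussian density (the components are assumed known)."""
    return math.exp(-0.5 * (y - mean) ** 2) / math.sqrt(2 * math.pi)


def lower_bound(lik, resp, alpha, alpha0):
    """F = <ln P(Y, s | pi)> + H[Q(s)] - KL(Q(pi) || P(pi))."""
    K = len(alpha)
    psi_sum = digamma(sum(alpha))
    e_ln_pi = [digamma(a) - psi_sum for a in alpha]
    f = 0.0
    for row, r in zip(lik, resp):
        for k in range(K):
            if r[k] > 0.0:
                f += r[k] * (math.log(row[k]) + e_ln_pi[k] - math.log(r[k]))
    kl = (math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
          - math.lgamma(K * alpha0) + K * math.lgamma(alpha0)
          + sum((a - alpha0) * e for a, e in zip(alpha, e_ln_pi)))
    return f - kl


def vb_mixing_proportions(ys, means, alpha0=1.0, n_iter=30):
    """Variational EM for pi ~ Dirichlet(alpha0, ..., alpha0)."""
    K = len(means)
    alpha = [alpha0] * K  # Q(pi) initialised at the prior
    lik = [[gauss_pdf(y, m) for m in means] for y in ys]
    bounds = []
    for _ in range(n_iter):
        # VE step: Q(s_n = k) proportional to exp(<ln pi_k>) * P(y_n | s_n = k)
        psi_sum = digamma(sum(alpha))
        e_ln_pi = [digamma(a) - psi_sum for a in alpha]
        resp = []
        for row in lik:
            w = [math.exp(e) * p for e, p in zip(e_ln_pi, row)]
            z = sum(w)
            resp.append([wi / z for wi in w])
        # VM step: conjugate update with the expected sufficient statistics
        alpha = [alpha0 + sum(r[k] for r in resp) for k in range(K)]
        bounds.append(lower_bound(lik, resp, alpha, alpha0))
    return alpha, resp, bounds
```

Calling `vb_mixing_proportions(ys, means)` on a small data set returns the Dirichlet parameters of Q(π), the responsibilities, and the sequence of bound values. The bound should never decrease from one iteration to the next, and a component that explains no data keeps a Dirichlet parameter close to its prior value.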
Belief networks
Corollary 1: Conjugate-Exponential Belief Networks. Let $M$ be a conjugate-exponential model with hidden and visible variables $z = (x, y)$ that satisfy a belief network factorisation. That is, each variable $z_j$ has parents $z_{p_j}$ and $P(z \mid \theta) = \prod_j P(z_j \mid z_{p_j}, \theta)$. Then the approximating joint distribution for $M$ satisfies the same belief network factorisation:
$$Q_z(z) = \prod_j Q(z_j \mid z_{p_j}, \tilde{\theta})$$
where the conditional distributions have exactly the same form as those in the original model but with natural parameters $\phi(\tilde{\theta}) = \bar{\phi}(\theta)$. Furthermore, with the modified parameters $\tilde{\theta}$, the expectations under the approximating posterior $Q_x(x) \propto Q_z(z)$ required for the VE Step can be obtained by applying the belief propagation algorithm if the network is singly connected and the junction tree algorithm if the network is multiply connected.
Markov networks
Theorem 2: Markov Networks. Let $M$ be a model with hidden and visible variables $z = (x, y)$ that satisfy a Markov network factorisation. That is, the joint density can be written as a product of clique potentials $\psi_j$, $P(z \mid \theta) = g(\theta) \prod_j \psi_j(C_j, \theta)$, where each clique $C_j$ is a subset of the variables in $z$. Then the approximating joint distribution for $M$ satisfies the same Markov network factorisation:
$$Q_z(z) = \tilde{g} \prod_j \bar{\psi}_j(C_j)$$
where $\bar{\psi}_j(C_j) = \exp\left\{ \langle \ln \psi_j(C_j, \theta) \rangle_Q \right\}$ are new clique potentials obtained by averaging over $Q_\theta(\theta)$, and $\tilde{g}$ is a normalisation constant. Furthermore, the expectations under the approximating posterior $Q_x(x)$ required for the VE Step can be obtained by applying the junction tree algorithm.
Corollary 2: Conjugate-Exponential Markov Networks. Let $M$ be a conjugate-exponential Markov network over the variables in $z$. Then the approximating joint distribution for $M$ is given by $Q_z(z) = \tilde{g} \prod_j \psi_j(C_j, \tilde{\theta})$, where the clique potentials have exactly the same form as those in the original model but with natural parameters $\phi(\tilde{\theta}) = \bar{\phi}(\theta)$.
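Mixture of factor analysers
A $p$-dimensional vector $y$ is generated from $k$ unobserved independent Gaussian sources, $x$, then corrupted with Gaussian noise, $r$, with diagonal covariance matrix $\Psi$: $y = \Lambda x + r$. The marginal density of $y$ is Gaussian with zero mean and covariance $\Lambda \Lambda^\top + \Psi$.
[Figure: graphical model of the mixture of factor analysers, with nodes $a, b$; $\nu^1, \nu^2, \ldots, \nu^S$; $\Lambda^1, \Lambda^2, \ldots, \Lambda^S$; $\alpha$; $\pi$; $\Psi$; and $s^n$, $x^n$, $y^n$ inside a plate over $n = 1 \ldots N$.]
• Complete data likelihood is in the exponential family.
• Choose conjugate priors: ($\nu \sim$ Gamma, $\Lambda \sim$ Gaussian, $\pi \sim$ Dirichlet).
VE: posteriors over hidden states calculated as usual.
VM: posteriors over parameters have the same form as the priors.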
Synthetic example
True data: 6 Gaussian clusters with dimensions (1 7 4 3 2 2), embedded in 10 dimensions.
Inferred structure:
• Finds the clusters and their dimensionalities.
• Model complexity reduces in line with lack of data support.
number of points | intrinsic dimensionalities per cluster
(true) | 1 7 4 3 2 2
8 | 2 1
8 | 1 2
16 | 1 4 2
32 | 1 6 3 3 2 2
64 | 1 7 4 3 2 2
128 | 1 7 4 3 2 2
Linear Dynamical Systems
[Figure: graphical model with a chain of hidden states $X_1, X_2, X_3, \ldots, X_T$ and observations $Y_1, Y_2, Y_3, \ldots, Y_T$.]
• Sequence of $D$-dimensional real-valued observed vectors $y_{1:T}$.
• Assume that at each time step $t$, $y_t$ was generated from a $k$-dimensional real-valued hidden state variable $x_t$, and that the sequence $x_{1:T}$ defines a first-order Markov process.
• The joint probability is given by
$$P(x_{1:T}, y_{1:T}) = P(x_1) P(y_1 \mid x_1) \prod_{t=2}^{T} P(x_t \mid x_{t-1}) P(y_t \mid x_t).$$
• If the transition and output functions are linear and time-invariant, and the noise distributions are Gaussian, this is a linear-Gaussian state-space model:
$$x_t = A x_{t-1} + w_t, \qquad y_t = C x_t + r_t.$$
Model selection using ARD
[Figure: panels labelled "Dynamics matrix, A" and "Output matrix, C".]
• Gaussian priors on the entries in the columns of these matrices:
$$P(a_i \mid \alpha) = \mathcal{N}(0, \mathrm{diag}(\alpha)^{-1}), \qquad P(c_i \mid \beta) = \mathcal{N}(0, \mathrm{diag}(\rho_i \beta)^{-1}).$$
• $\alpha$ and $\beta$ are vectors of hyperparameters, to be optimised during learning. Some $\alpha, \beta \to \infty$.
• Empty columns signify extinct hidden dimensions: in $A$ for modelling hidden dynamics, in $C$ for modelling output covariance structure.
LDS variational approximation
Parameters of the model $\theta = (A, C, \rho)$, hidden states $x_{1:T}$.
$$\ln P(Y) = \ln \int dA \, dC \, d\rho \, dx_{1:T} \; P(A, C, \rho, x_{1:T}, y_{1:T}).$$
Lower bound (Jensen's inequality):
$$\ln P(Y) \ge \mathcal{F}(Q, Y) = \int dA \, dC \, d\rho \, dx_{1:T} \; Q(A, C, \rho, x_{1:T}) \ln \frac{P(A, C, \rho, x_{1:T}, y_{1:T})}{Q(A, C, \rho, x_{1:T})}.$$
Factored approximation to the true posterior:
$$Q(A, C, \rho, x_{1:T}) = Q_A(A) \, Q_{C\rho}(C, \rho) \, Q_x(x_{1:T}).$$
This factorisation falls out from the initial assumptions and the conditional independencies in the model.
Priors:
• Sensor precisions, $\rho$, given conjugate gamma priors.
• Transition and output matrices given conjugate zero-mean Gaussian priors (ARD).
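Optimal $Q(\theta)$ distributions: LDS
Complete data likelihood is of conjugate-exponential form: iterative analytic maximisations.
Dynamics matrix $Q_A(A)$: $A \sim \mathcal{N}(a_i; \bar{a}_i, \Sigma^A)$, with $\bar{a}_i = S_i^\top \Sigma^A$ and $\Sigma^A = (\mathrm{diag}(\alpha) + W)^{-1}$, where $S_i$ is the $i$-th column of $S$.
Noise covariance $Q_\rho(\rho)$: $\rho_i \sim \mathrm{Gamma}(\rho_i; \tilde{a}, \tilde{b})$, with $\tilde{a} = a + T/2$ and $\tilde{b} = b + g_i/2$.
Output matrix $Q_C(C \mid \rho_i)$: $C \sim \mathcal{N}(c_i; \bar{c}_i, \Sigma^C_i)$, with $\bar{c}_i = \rho_i U_i \Sigma^C_i$ and $\Sigma^C_i = (\mathrm{diag}(\beta) + W')^{-1} / \rho_i$.
Hidden state $Q_x(x_{1:T})$: $x_t \sim \mathcal{N}(x_t; \eta^{(T)}_t, \Psi^{(T)}_t)$, with $\{\eta^{(T)}_t, \Psi^{(T)}_t\} = \mathrm{KalmanSmoother}(Q_{AC\rho}, y_{1:T})$.
Hyperparameters $\alpha, \beta$: $1/\alpha = \langle a \, a \rangle_{Q_A}$, $1/\beta = \langle c \, c \rangle_{Q_{C\rho}}$.
Sufficient statistics:
$$S = \sum_{t=2}^{T} \langle x_{t-1} x_t^\top \rangle, \qquad W = \sum_{t=1}^{T-1} \langle x_t x_t^\top \rangle, \qquad W' = W + \langle x_T x_T^\top \rangle,$$
$$U_i = \sum_{t=1}^{T} y_{ti} \langle x_t^\top \rangle, \qquad g_i = \sum_{t=1}^{T} y_{ti}^2 - U_i (\mathrm{diag}(\beta) + W')^{-1} U_i^\top.$$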
Method overview
1. Randomly initialise the approximate posteriors;
2. Update each posterior in turn according to the variational E and M steps: $Q_A(A)$, $Q_{C\rho}(C \mid \rho)$, $Q_\rho(\rho)$, $Q_x(x_{1:T})$;
3. Update the hyperparameters $\alpha$ and $\beta$;
4. Repeat steps 2-3 until convergence.
Variational E Step for LDS
[Figure: graphical model with a chain of hidden states $X_1, X_2, X_3, \ldots, X_T$ and observations $Y_1, Y_2, Y_3, \ldots, Y_T$.]
• The model is a conjugate-exponential, singly connected belief network.
• By Corollary 1, use belief propagation, which for LDSs is the Kalman smoother. Forward: infer state $x_t$ given past observations. Backward: infer state $x_t$ given future observations.
• Inference using an ensemble of parameters is possible using just a single setting of the parameters. Required averages: $\langle \rho_i c_i \rangle$, $\langle \rho_i c_i c_i^\top \rangle$, $\langle A \rangle$, $\langle A^\top A \rangle$.
• The result is a recursive algorithm very similar to the Kalman smoothing propagation.
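For orientation, the forward and backward passes mentioned above can be sketched for the ordinary, single-parameter-setting case. The code below is a standard scalar Kalman filter with a Rauch-Tung-Striebel backward pass, not the variational recursion of the slides, which replaces the fixed parameters with the required averages. The scalar model and the names `a`, `c`, `q`, `r`, `m0`, `v0` are assumptions made for illustration.

```python
def kalman_smoother(ys, a, c, q, r, m0=0.0, v0=1.0):
    """Scalar model (illustrative):
         x_t = a * x_{t-1} + w_t,  w_t ~ N(0, q)
         y_t = c * x_t + r_t,      r_t ~ N(0, r)
         x_1 ~ N(m0, v0)
    Returns filtered means/variances and smoothed means/variances."""
    T = len(ys)
    pm, pv = [0.0] * T, [0.0] * T  # one-step predictions
    fm, fv = [0.0] * T, [0.0] * T  # filtered estimates
    # Forward pass: infer x_t from y_1..y_t
    for t, y in enumerate(ys):
        if t == 0:
            pm[t], pv[t] = m0, v0
        else:
            pm[t] = a * fm[t - 1]
            pv[t] = a * a * fv[t - 1] + q
        gain = pv[t] * c / (c * c * pv[t] + r)
        fm[t] = pm[t] + gain * (y - c * pm[t])
        fv[t] = (1.0 - gain * c) * pv[t]
    # Backward pass: fold in the information from y_{t+1}..y_T
    sm, sv = fm[:], fv[:]
    for t in range(T - 2, -1, -1):
        j = fv[t] * a / pv[t + 1]
        sm[t] = fm[t] + j * (sm[t + 1] - pm[t + 1])
        sv[t] = fv[t] + j * j * (sv[t + 1] - pv[t + 1])
    return fm, fv, sm, sv
```

At the final time step the smoothed and filtered estimates coincide, and at every earlier step the smoothed variance is no larger than the filtered one, since the backward pass only adds information.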
Synthetic experiments
Generated a variety of state-space models and looked at the recovered hidden structure.
Observed data:
• 10-dim output
• 200 time-steps
Results:
• True model: 3-dim static state-space. [Figure: inferred model; labels A, C, X, X, Y.]
• True model: 3-dim dynamical state-space. [Figure: inferred model; labels X, X, Y.]
• True model: 4-dim state-space, of which 3 dynamical. [Figure: inferred model; labels X, X, Y.]
Degradation with data support
True model:
• 6-dim state-space
• 10-dim output
Complexity of recovered structure decreases with decreasing data support (400 to 10 time-steps).
[Figure: inferred models for 400, 350, 250, 100, 30, 20 and 10 time-steps; each panel labelled X, X, Y.]
Summary
• Bayesian learning avoids overfitting and can be used to learn model structure.
• Tractable learning using variational approximations.
• Conjugate-exponential families of models.
• Variational EM and integration of existing algorithms.
• Examples learning structure in MFA and LDS models.
Issues & Extensions
• Quality of the variational approximations?
• How best to bring a model into the exponential family?
• Extensions to other models, e.g. switching state-space models [Ueda], are straightforward.
• Problems integrating existing algorithms into the VE step.
• Implementation on a general graphical model.
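Optimal $Q(\theta)$ distributions: MFA
$Q(x^n \mid s^n) \sim \mathcal{N}(\bar{x}^{n,s}, \Sigma^s)$:
$$\bar{x}^{n,s} = \Sigma^s \bar{\Lambda}^{s\top} \Psi^{-1} y^n, \qquad \Sigma^{s\,-1} = \langle \Lambda^{s\top} \Psi^{-1} \Lambda^s \rangle_{Q(\Lambda^s)} + I$$
$Q(\Lambda^s_q) \sim \mathcal{N}(\bar{\Lambda}^s_q, \Sigma^{q,s})$:
$$\bar{\Lambda}^s_q = \left[ \Psi^{-1}_{qq} \sum_{n=1}^{N} Q(s^n) \, y^n_q \, \bar{x}^{n,s\top} \right] \Sigma^{q,s}, \qquad \Sigma^{q,s\,-1} = \Psi^{-1}_{qq} \sum_{n=1}^{N} Q(s^n) \langle x^n x^{n\top} \rangle_{Q(x^n \mid s^n)} + \mathrm{diag} \langle \nu^s \rangle_{Q(\nu^s)}$$
$Q(\nu^s_l) \sim \mathrm{Gamma}(a^s_l, b^s_l)$:
$$a^s_l = a + \frac{p}{2}, \qquad b^s_l = b + \frac{1}{2} \sum_{q=1}^{p} \langle \Lambda^{s\,2}_{ql} \rangle_{Q(\Lambda^s)}$$
$Q(\pi) \sim \mathrm{Dirichlet}(\omega u)$:
$$\omega u_s = \frac{\alpha}{S} + \sum_{n=1}^{N} Q(s^n)$$
$$\ln Q(s^n) = [\psi(\omega u_s) - \psi(\omega)] + \frac{1}{2} \ln |\Sigma^s| + \langle \ln P(y^n \mid x^n, s^n, \Lambda^s, \Psi) \rangle_{Q(x^n \mid s^n) Q(\Lambda^s)} + c$$
Note that the optimal distributions $Q(\Lambda^s)$ have block diagonal covariance structure. So even though each $\Lambda^s$ is a $p \times q$ matrix, its covariance only has $O(pq^2)$ parameters.
Differentiating $\mathcal{F}$ with respect to the parameters, $a$ and $b$, of the precision prior we get the fixed point equations $\psi(a) = \langle \ln \nu \rangle + \ln b$ and $b = a / \langle \nu \rangle$. Similarly, the fixed point for the hyperparameters of the Dirichlet prior is $\psi(\alpha) - \psi(\alpha/S) + \sum_s [\psi(\omega u_s) - \psi(\omega)] / S = 0$.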
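Sampling from Variational Approximations
Sampling $\theta_m \sim Q(\theta)$ gives us estimates of:
• The exact predictive density:
$$P(y \mid Y) = \int d\theta \; P(y \mid \theta) P(\theta \mid Y) = \int d\theta \; Q(\theta) P(y \mid \theta) \frac{P(\theta \mid Y)}{Q(\theta)} \approx \sum_{m=1}^{M} P(y \mid \theta_m) \, \omega_m,$$
with weights $\omega_m = \Omega^{-1} \dfrac{P(\theta_m, Y)}{Q(\theta_m)}$, with $\Omega$ such that $\sum_m \omega_m = 1$.
• The true evidence:
$$P(Y \mid m) = \int d\theta \; Q(\theta) \frac{P(\theta, Y)}{Q(\theta)} = \langle \Omega \omega \rangle.$$
• The KL divergence:
$$\mathrm{KL}(Q(\theta) \,\|\, P(\theta \mid Y)) = \ln \langle \omega \rangle - \langle \ln \omega \rangle.$$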
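The evidence and KL estimators above can be sketched on a toy problem where the exact answers are known. The Bernoulli likelihood with a Beta prior, the Beta-shaped Q, and all function names below are assumptions made for illustration. The code works with the unnormalised log-weights ln P(θ_m, Y) − ln Q(θ_m), since Ω cancels in the KL estimate.

```python
import math
import random


def log_beta_fn(a, b):
    """ln B(a, b), the log of the Beta function."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_joint(theta, heads, tails, a0, b0):
    """ln P(theta, Y) for a Bernoulli sequence with a Beta(a0, b0) prior."""
    return ((a0 - 1 + heads) * math.log(theta)
            + (b0 - 1 + tails) * math.log(1 - theta)
            - log_beta_fn(a0, b0))


def log_q(theta, qa, qb):
    """ln Q(theta) for a Beta(qa, qb) approximating distribution."""
    return ((qa - 1) * math.log(theta) + (qb - 1) * math.log(1 - theta)
            - log_beta_fn(qa, qb))


def importance_estimates(heads, tails, a0, b0, qa, qb, n_samples=2000, seed=0):
    """Return (estimate of ln P(Y), estimate of KL(Q || P(theta | Y)))."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(n_samples):
        theta = rng.betavariate(qa, qb)  # theta_m ~ Q(theta)
        log_w.append(log_joint(theta, heads, tails, a0, b0)
                     - log_q(theta, qa, qb))
    # ln <w>, computed stably with the log-mean-exp trick
    top = max(log_w)
    log_evidence = top + math.log(
        sum(math.exp(lw - top) for lw in log_w) / n_samples)
    # KL = ln <w> - <ln w>
    kl = log_evidence - sum(log_w) / n_samples
    return log_evidence, kl


def exact_log_evidence(heads, tails, a0, b0):
    """Closed-form ln P(Y) for the Beta-Bernoulli model."""
    return log_beta_fn(a0 + heads, b0 + tails) - log_beta_fn(a0, b0)
```

When Q is set to the exact posterior, Beta(a0 + heads, b0 + tails), every weight equals the evidence and the KL estimate is zero up to rounding. For any other Q the KL estimate is non-negative by Jensen's inequality, and the evidence estimate becomes noisier as Q moves further from the posterior.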
Variational Bayes & Ensemble Learning
• multilayer perceptrons (Hinton & van Camp, 1993)
• mixture of experts (Waterhouse, MacKay & Robinson, 1996)
• hidden Markov models (MacKay, 1995)
• other work by Jaakkola, Barber, Bishop, Tipping, etc.
Examples of VB Learning Model Structure
Model learning has been treated with variational Bayesian techniques for:
• mixtures of factor analysers (Ghahramani & Beal, 1999)
• mixtures of Gaussians (Attias, 1999)
• independent components analysis (Attias, 1999; Miskin & MacKay, 2000; Valpola, 2000)
• principal components analysis (Bishop, 1999)
• linear dynamical systems (Ghahramani & Beal, 2000)
• mixture of experts (Ueda & Ghahramani, 2000)
• hidden Markov models (Ueda & Ghahramani, in prep)
(compiled by Zoubin)
[Figure: "Evolution of F, true evidence and KL-divergence". Axes: train and test fractional error, and F, against iteration (0 to 8000). Curves: train error, test error, ln p(y), train F(π,Λ), train F(θ).]
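Required expectations for linear-Gaussian state-space models
Defining $G' = \mathrm{diag}\!\left( \dfrac{a + T/2}{b + G/2} \right)$:
$$\langle C^\top R^{-1} C \rangle = (\mathrm{diag}\,\beta + W')^{-1} \left[ p + U G' U^\top (\mathrm{diag}\,\beta + W')^{-1} \right]$$
$$\langle C^\top R^{-1} \rangle = (\mathrm{diag}\,\beta + W')^{-1} U G'$$
$$\langle A \rangle = S^\top (\mathrm{diag}\,\alpha + W)^{-1}$$
$$\langle A^\top A \rangle = k \, (\mathrm{diag}\,\alpha + W)^{-1} + \langle A \rangle^\top \langle A \rangle$$
where $S = \sum_{t=2}^{T} \langle x_{t-1} x_t^\top \rangle$, $W = \sum_{t=1}^{T-1} \langle x_t x_t^\top \rangle$, $U_i = \sum_{t=1}^{T} y_{ti} \langle x_t^\top \rangle$ and $W' = W + \langle x_T x_T^\top \rangle$.
Kalman smoothing propagation
Forward recursion:
$$\Sigma^*_t = \left( \Sigma_t^{-1} + \langle A^\top A \rangle \right)^{-1}$$
$$\mu_t = \Sigma_t \left[ \langle C^\top R^{-1} \rangle y_t + \langle A \rangle \Sigma^*_{t-1} \Sigma_{t-1}^{-1} \mu_{t-1} \right]$$
$$\Sigma_t = \left[ I + \langle C^\top R^{-1} C \rangle - \langle A \rangle \Sigma^*_{t-1} \langle A \rangle^\top \right]^{-1}$$
Backward recursion:
$$\Psi_t = \left[ \Sigma_t^{*\,-1} - \langle A \rangle^\top \left( \Psi_{t+1}^{-1} + \langle A \rangle \Sigma^*_t \langle A \rangle^\top \right)^{-1} \langle A \rangle \right]^{-1}$$
[The backward recursion for the smoothed mean and the expression for $\mathrm{cov}(x_t, x_{t+1})$ are illegible in the source.]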