Variational inference in the conjugate-exponential family


Variational inference in the conjugate-exponential family Matthew J. Beal Work with Zoubin Ghahramani Gatsby Computational Neuroscience Unit August 2000

Abstract

Variational inference in the conjugate-exponential family

When learning models from data there is always the problem of over-fitting and poor generalisation. In a Bayesian approach this can be resolved by averaging over all possible settings of the model parameters. However, for most interesting problems, averaging over models leads to intractable integrals, and so approximations are necessary. Variational methods are becoming a widespread tool for approximating inference and learning in many graphical models: they are deterministic, generally fast, and can monotonically increase an objective function which transparently incorporates a model complexity cost. I'll provide some theoretical results for the variational updates in a very general family of conjugate-exponential graphical models, and show how well-known algorithms (e.g. belief propagation) can be readily incorporated in the variational updates. Some examples illustrate the ideas: determining the most probable number of clusters and their intrinsic latent-space dimensionality in a Mixture of Factor Analysers model; and recovering the hidden state-space dimensionality in a Linear Dynamical System model. This is work with Zoubin Ghahramani.

Outline
- Briefly review variational Bayesian learning.
- Concentrate on conjugate-exponential models.
- Variational Expectation-Maximisation (VEM).
- Examples in Mixture of Factor Analysers (MFA) and Linear Dynamical Systems (LDS).

Bayesian approach

Motivation
- Maximum likelihood (ML) does not penalise complex models, which can a priori model a larger range of data sets.
- Bayes does not suffer from overfitting if we integrate out all the parameters.
- We should be able to do principled model comparison.

Method
- Express distributions over all parameters in the model.
- Calculate the evidence:  P(Y | m) = ∫ dθ P(Y | θ, m) P(θ | m).

[Figure: P(Data) plotted against data sets Y for "too simple", "just right" and "too complex" models; adapted from D.J.C. MacKay.]

Problem
- These integrals are computationally intractable.
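As a toy illustration of the evidence and its built-in complexity penalty (an added sketch, not from the slides): for a Bernoulli model with a Beta(a, b) prior the integral is tractable in closed form, P(Y | m) = B(a + n_1, b + n_0) / B(a, b), where n_1 and n_0 count the ones and zeros in Y. The numpy/scipy sketch below compares it against a fixed fair-coin model; the data, prior counts and model pair are arbitrary choices for illustration.

import numpy as np
from scipy.special import betaln

y = np.array([1, 1, 0, 1, 1, 1, 0, 1, 1, 1])   # toy coin-flip data
n1, n0 = y.sum(), len(y) - y.sum()

# Model 1: fixed fair coin (no free parameters)
log_ev_fair = len(y) * np.log(0.5)

# Model 2: Bernoulli with a Beta(1,1) prior on the bias, parameters integrated out:
#   P(Y | m) = B(a + n1, b + n0) / B(a, b)
a, b = 1.0, 1.0
log_ev_beta = betaln(a + n1, b + n0) - betaln(a, b)

print("log evidence, fair coin :", round(log_ev_fair, 3))
print("log evidence, free bias :", round(log_ev_beta, 3))

For latent-variable models the analogous integral rarely has such a closed form, which is what motivates the approximations below.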

Practical approaches
- Laplace approximations: appeal to the Central Limit Theorem, making a Gaussian approximation about the maximum a posteriori parameter estimate.
- Large sample approximations: e.g. BIC.
- Markov chain Monte Carlo (MCMC): guaranteed to converge in the limit; many samples required for accurate results; hard to assess convergence.
- Variational approximations.

The variational method

Let the hidden states be X and the parameters θ:

  ln P(Y) = ln ∫ dX dθ P(Y, X, θ) ≥ ∫ dX dθ Q(X, θ) ln [ P(Y, X, θ) / Q(X, θ) ].

Approximate the intractable distribution over hidden states (X) and parameters (θ) with a simpler, factorised distribution:

  ln P(Y) = ln ∫ dX dθ P(Y, X, θ) ≥ ∫ dX dθ Q_X(X) Q_θ(θ) ln [ P(Y, X, θ) / (Q_X(X) Q_θ(θ)) ] = F(Q, Y).

Maximisation of this lower bound leads to EM-like updates:

  E step:  Q*_X(X) ∝ exp⟨ln P(X, Y | θ)⟩_{Q_θ(θ)}
  M step:  Q*_θ(θ) ∝ P(θ) exp⟨ln P(X, Y | θ)⟩_{Q_X(X)}

Equivalent to minimising the KL-divergence between the approximating and true posteriors.
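The equivalence can be checked numerically on a model small enough to enumerate. The toy sketch below (an added illustration with an arbitrary fully discrete model and an arbitrary factorised Q) verifies that ln P(Y) = F(Q, Y) + KL(Q_X Q_θ ‖ P(X, θ | Y)), so maximising F is the same as minimising the KL divergence.

import numpy as np

rng = np.random.default_rng(1)

# Tiny discrete toy: theta and X each take 3 values, Y is a single binary observation.
n_theta, n_x = 3, 3
P_theta = rng.dirichlet(np.ones(n_theta))                         # prior P(theta)
P_xy_given_theta = rng.dirichlet(np.ones(n_x * 2), size=n_theta)  # P(x, y | theta)
P_xy_given_theta = P_xy_given_theta.reshape(n_theta, n_x, 2)
y = 1                                                             # the observed data

joint = P_theta[:, None] * P_xy_given_theta[:, :, y]              # P(theta, X, Y=y)
P_Y = joint.sum()
post = joint / P_Y                                                # P(theta, X | Y)

# An arbitrary factorised approximation Q_theta(theta) Q_X(X)
Q_theta = rng.dirichlet(np.ones(n_theta))
Q_X = rng.dirichlet(np.ones(n_x))
Q = Q_theta[:, None] * Q_X[None, :]

F = np.sum(Q * (np.log(joint) - np.log(Q)))                       # lower bound F(Q, Y)
KL = np.sum(Q * (np.log(Q) - np.log(post)))                       # KL(Q || P(theta, X | Y))

print(np.log(P_Y), F + KL)   # the two numbers agree: ln P(Y) = F + KL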

a bit more explanation

  F(Q, Y) = ∫ dX dθ Q_X(X) Q_θ(θ) ln [ P(X, Y, θ) / (Q_X(X) Q_θ(θ)) ]
          = ∫ dθ Q_θ(θ) [ ln ( P(θ) / Q_θ(θ) ) + ∫ dX Q_X(X) ln ( P(X, Y | θ) / Q_X(X) ) ].

e.g. setting a derivative to zero:

  ∂F(Q, Y) / ∂Q_θ(θ) ∝ ln P(θ) − 1 − ln Q_θ(θ) + ⟨ln P(X, Y | θ)⟩_{Q_X(X)}

  E step:  Q*_X(X) ∝ exp⟨ln P(X, Y | θ)⟩_{Q_θ(θ)}
  M step:  Q*_θ(θ) ∝ P(θ) exp⟨ln P(X, Y | θ)⟩_{Q_X(X)}

Question: what distributions can we make use of in our models?

Conjugate-Exponential models

Consider variational Bayesian learning in models that satisfy:

Condition (1). The complete-data likelihood is in the exponential family:

  P(x, y | θ) = f(x, y) g(θ) exp{ φ(θ)ᵀ u(x, y) }

where φ(θ) is the vector of natural parameters, and u, f and g are the functions that define the exponential family.

Condition (2). The parameter prior is conjugate to the complete-data likelihood:

  P(θ | η, ν) = h(η, ν) g(θ)^η exp{ φ(θ)ᵀ ν }

where η and ν are hyperparameters of the prior.

We call models that satisfy conditions (1) and (2) conjugate-exponential.
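As a minimal worked instance (an added example: a single Bernoulli observation with a Beta prior, hidden variables omitted), conditions (1) and (2) hold with f(y) = 1, g(θ) = 1 − θ, φ(θ) = ln(θ/(1−θ)) and u(y) = y, and the conjugate prior g(θ)^η exp{φ(θ) ν} ∝ θ^ν (1−θ)^{η−ν} is a Beta(ν+1, η−ν+1) density. The sketch below checks both identities numerically (the particular values of θ, η and ν are arbitrary):

import numpy as np
from scipy.stats import beta

# Condition (1): f(y) g(theta) exp(phi(theta) u(y)) reproduces the Bernoulli pmf
theta, y = 0.3, 1
f, g, phi, u = 1.0, 1.0 - theta, np.log(theta / (1.0 - theta)), y
print(f * g * np.exp(phi * u), theta**y * (1 - theta)**(1 - y))   # identical

# Condition (2): g(theta)^eta exp(phi(theta) nu) is, up to normalisation,
# theta^nu (1-theta)^(eta-nu), i.e. a Beta(nu+1, eta-nu+1) density
eta, nu = 5.0, 2.0
grid = np.linspace(0.01, 0.99, 5)
unnorm = (1 - grid)**eta * np.exp(np.log(grid / (1 - grid)) * nu)
print(unnorm / beta.pdf(grid, nu + 1, eta - nu + 1))              # constant ratio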

Exponential family models

Some models that are in the exponential family:
- Gaussian mixtures,
- factor analysis, probabilistic PCA,
- hidden Markov models and factorial HMMs,
- linear dynamical systems and switching state-space models,
- Boltzmann machines,
- discrete-variable belief networks.

Other as yet undreamt-of models can combine Gaussian, Gamma, Poisson, Dirichlet, Wishart, Multinomial and others.

Some which are not in the exponential family:
- sigmoid belief networks,
- independent components analysis.

Make approximations/changes to move these into the exponential family...

Theoretical Results

Theorem 1. Given an iid data set Y = (y_1, ..., y_n), if the model satisfies conditions (1) and (2), then at the maxima of F(Q, Y) (minima of KL(Q ‖ P)):

(a) Q_θ(θ) is conjugate and of the form

  Q_θ(θ) = h(η̃, ν̃) g(θ)^η̃ exp{ φ(θ)ᵀ ν̃ }

where η̃ = η + n, ν̃ = ν + Σ_{i=1}^n ū(x_i, y_i), and ū(x_i, y_i) = ⟨u(x_i, y_i)⟩_Q, using ⟨·⟩_Q to denote expectation under Q.

(b) Q_X(X) = Π_{i=1}^n Q_{x_i}(x_i), and Q_{x_i}(x_i) is of the same form as the known-parameter posterior:

  Q_{x_i}(x_i) ∝ f(x_i, y_i) exp{ φ̄(θ)ᵀ u(x_i, y_i) } = P(x_i | y_i, φ̄(θ))

where φ̄(θ) = ⟨φ(θ)⟩_Q.

KEY points: (a) the approximate parameter posterior is of the same form as the prior; (b) the approximate hidden-variable posterior is of the same form as the hidden-variable posterior for known parameters.

Variational EM algorithm

Since Q_θ(θ) and Q_{x_i}(x_i) are coupled, (a) and (b) do not provide an analytic solution to the minimisation problem.

VE step: compute the expected sufficient statistics t(Y) = Σ_i ū(x_i, y_i) under the hidden-variable distributions Q_{x_i}(x_i).

VM step: compute the expected natural parameters φ̄(θ) under the parameter distribution given by η̃ and ν̃.

Properties:
- The VE step has the same complexity as the corresponding E step.
- Reduces to the EM algorithm if Q_θ(θ) = δ(θ − θ*); the M step then involves re-estimation of θ*.
- F increases monotonically, and incorporates a model complexity penalty.
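To make the VE/VM alternation concrete, here is a toy sketch (an added example, not from the talk): a mixture of 1-D Gaussians with known unit variances and known mixing weights, and conjugate Gaussian priors on the unknown component means. The VE step forms responsibilities from the expected natural parameters ⟨μ_k⟩ and ⟨μ_k²⟩; the VM step rebuilds each Gaussian posterior from the expected sufficient statistics. All data and hyperparameter settings are arbitrary.

import numpy as np

rng = np.random.default_rng(0)

# synthetic data from two well-separated components
y = np.concatenate([rng.normal(-2.0, 1.0, 100), rng.normal(3.0, 1.0, 100)])
K, N = 2, y.size
log_pi = np.log(np.full(K, 1.0 / K))      # known mixing weights
m0, tau0 = 0.0, 1e-2                      # prior mean and precision for each mu_k

m = rng.normal(size=K)                    # posterior means  <mu_k>
tau = np.full(K, tau0)                    # posterior precisions

for it in range(50):
    # VE step: responsibilities use the expected natural parameters, i.e.
    # <ln N(y_n | mu_k, 1)> under Q(mu_k) = N(m_k, 1/tau_k)
    E_mu, E_mu2 = m, m**2 + 1.0 / tau
    log_r = log_pi + y[:, None] * E_mu - 0.5 * (y[:, None] ** 2 + E_mu2)
    log_r -= log_r.max(axis=1, keepdims=True)
    r = np.exp(log_r)
    r /= r.sum(axis=1, keepdims=True)

    # VM step: conjugate Gaussian posterior over each mean, built from the
    # expected sufficient statistics sum_n r_nk and sum_n r_nk * y_n
    Nk = r.sum(axis=0)
    tau = tau0 + Nk
    m = (tau0 * m0 + r.T @ y) / tau

print("posterior means:", np.round(m, 2))   # approach -2 and 3 (up to ordering)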

Belief networks

Corollary 1: Conjugate-Exponential Belief Networks. Let M be a conjugate-exponential model with hidden and visible variables z = (x, y) that satisfy a belief network factorisation. That is, each variable z_j has parents z_{p_j} and

  P(z | θ) = Π_j P(z_j | z_{p_j}, θ).

Then the approximating joint distribution for M satisfies the same belief network factorisation:

  Q_z(z) = Π_j Q(z_j | z_{p_j}, θ̃)

where the conditional distributions have exactly the same form as those in the original model, but with natural parameters φ(θ̃) = φ̄(θ). Furthermore, with the modified parameters θ̃, the expectations under the approximating posterior Q_x(x) ∝ Q_z(z) required for the VE step can be obtained by applying the belief propagation algorithm if the network is singly connected, and the junction tree algorithm if the network is multiply connected.
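For example (an added sketch, assuming a discrete HMM with Dirichlet posteriors over the rows of the initial-state, transition and emission probabilities), the VE step is the ordinary forward-backward recursion, simply run with the sub-normalised matrices exp⟨ln θ⟩ in place of point estimates:

import numpy as np
from scipy.special import digamma

def expected_log_dirichlet(counts):
    # <ln theta> under a Dirichlet with the given (row-wise) counts
    return digamma(counts) - digamma(counts.sum(axis=-1, keepdims=True))

def ve_step_hmm(obs, trans_counts, emit_counts, init_counts):
    At = np.exp(expected_log_dirichlet(trans_counts))   # K x K, rows sum to < 1
    Bt = np.exp(expected_log_dirichlet(emit_counts))    # K x V
    pit = np.exp(expected_log_dirichlet(init_counts))   # K
    T, K = len(obs), len(pit)
    alpha = np.zeros((T, K)); beta = np.ones((T, K)); c = np.zeros(T)
    alpha[0] = pit * Bt[:, obs[0]]; c[0] = alpha[0].sum(); alpha[0] /= c[0]
    for t in range(1, T):                                # forward pass (scaled)
        alpha[t] = (alpha[t - 1] @ At) * Bt[:, obs[t]]
        c[t] = alpha[t].sum(); alpha[t] /= c[t]
    for t in range(T - 2, -1, -1):                       # backward pass
        beta[t] = (At @ (Bt[:, obs[t + 1]] * beta[t + 1])) / c[t + 1]
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    return gamma          # smoothed state marginals, as needed for the VM step

# toy usage: 2 hidden states, 3 symbols, flat pseudo-count posteriors
gamma = ve_step_hmm(np.array([0, 1, 2, 1]), np.ones((2, 2)), np.ones((2, 3)), np.ones(2))
print(gamma)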

Markov networks

Theorem 2: Markov Networks. Let M be a model with hidden and visible variables z = (x, y) that satisfy a Markov network factorisation. That is, the joint density can be written as a product of clique potentials ψ_j,

  P(z | θ) = g(θ) Π_j ψ_j(C_j, θ),

where each clique C_j is a subset of the variables in z. Then the approximating joint distribution for M satisfies the same Markov network factorisation:

  Q_z(z) = g̃ Π_j ψ̄_j(C_j),   where   ψ̄_j(C_j) = exp{ ⟨ln ψ_j(C_j, θ)⟩_Q }

are new clique potentials obtained by averaging over Q_θ(θ), and g̃ is a normalisation constant. Furthermore, the expectations under the approximating posterior Q_x(x) required for the VE step can be obtained by applying the junction tree algorithm.

Corollary 2: Conjugate-Exponential Markov Networks. Let M be a conjugate-exponential Markov network over the variables in z. Then the approximating joint distribution for M is given by Q_z(z) = g̃ Π_j ψ_j(C_j, θ̃), where the clique potentials have exactly the same form as those in the original model, but with natural parameters φ(θ̃) = φ̄(θ).

Mixture of factor analysers

A p-dimensional vector y is generated from k unobserved independent Gaussian sources, x, then corrupted with Gaussian noise, r, with diagonal covariance matrix Ψ:  y = Λx + r. The marginal density of y is Gaussian with zero mean and covariance ΛΛᵀ + Ψ.

[Figure: graphical model of the Bayesian mixture of factor analysers, with hyperparameters a, b and α, precision vectors ν^1, ..., ν^S, loading matrices Λ^1, ..., Λ^S, mixing proportions π, indicators s_n, latent factors x_n, noise covariance Ψ and observations y_n, n = 1, ..., N.]

- The complete-data likelihood is in the exponential family.
- Choose conjugate priors: ν ∼ Gamma, Λ ∼ Gaussian, π ∼ Dirichlet.

VE step: posteriors over hidden states calculated as usual.
VM step: posteriors over parameters have the same form as the priors.

Synthetic example

True data: 6 Gaussian clusters with dimensions (1 7 4 3 2 2) embedded in 10 dimensions.

Inferred structure:
- Finds the clusters and their dimensionalities.
- Model complexity reduces in line with lack of data support.

  number of points | intrinsic dimensionalities per cluster
  (true)           | 1 7 4 3 2 2
  8                | 2 1
  8                | 1 2
  16               | 1 4 2
  32               | 1 6 3 3 2 2
  64               | 1 7 4 3 2 2
  128              | 1 7 4 3 2 2

Linear Dynamical Systems

[Figure: graphical model of the state-space model, hidden states x_1, x_2, x_3, ..., x_T emitting observations y_1, y_2, y_3, ..., y_T.]

- A sequence of D-dimensional real-valued observed vectors y_{1:T}.
- Assume that at each time step t, y_t was generated from a k-dimensional real-valued hidden state variable x_t, and that the sequence x_{1:T} defines a first-order Markov process.
- The joint probability is given by

  P(x_{1:T}, y_{1:T}) = P(x_1) P(y_1 | x_1) Π_{t=2}^T P(x_t | x_{t-1}) P(y_t | x_t).

- If the transition and output functions are linear and time-invariant, and the noise distributions are Gaussian, this is a linear-Gaussian state-space model:

  x_t = A x_{t-1} + w_t,    y_t = C x_t + r_t.
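As a quick illustration of this generative process (an added toy sketch; the dimensions, noise scales and the particular stable dynamics matrix are arbitrary choices):

import numpy as np

def sample_lds(A, C, T, state_noise=1.0, obs_noise=1.0, rng=None):
    # x_t = A x_{t-1} + w_t,  y_t = C x_t + r_t, with Gaussian w_t, r_t
    if rng is None:
        rng = np.random.default_rng(0)
    k, D = A.shape[0], C.shape[0]
    x = np.zeros((T, k)); y = np.zeros((T, D))
    x[0] = rng.normal(0, 1, k)
    y[0] = C @ x[0] + rng.normal(0, obs_noise, D)
    for t in range(1, T):
        x[t] = A @ x[t - 1] + rng.normal(0, state_noise, k)
        y[t] = C @ x[t] + rng.normal(0, obs_noise, D)
    return x, y

# e.g. a 3-dimensional state observed through a 10-dimensional output
rng = np.random.default_rng(0)
A = 0.9 * np.eye(3)
C = rng.normal(size=(10, 3))
x, y = sample_lds(A, C, T=200, rng=rng)
print(x.shape, y.shape)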

Model selection using ARD

[Figure: Hinton diagrams of the dynamics matrix A and the output matrix C learnt by EM.]

- Gaussian priors on the entries in the columns of these matrices:

  P(a_i | α) = N(0, diag(α)^{-1}),    P(c_i | β) = N(0, diag(ρ_i β)^{-1})

- α and β are vectors of hyperparameters, to be optimised during learning. Some α, β → ∞.
- Empty columns signify extinct hidden dimensions: for modelling the hidden dynamics, and for modelling the output covariance structure.

LDS variational approximation

Parameters of the model: θ = (A, C, ρ); hidden states x_{1:T}.

  ln P(Y) = ln ∫ dA dC dρ dx_{1:T} P(A, C, ρ, x_{1:T}, y_{1:T}).

Lower bound (Jensen's inequality):

  ln P(Y) ≥ F(Q, Y) = ∫ dA dC dρ dx_{1:T} Q(A, C, ρ, x_{1:T}) ln [ P(A, C, ρ, x_{1:T}, y_{1:T}) / Q(A, C, ρ, x_{1:T}) ].

Factored approximation to the true posterior:

  Q(A, C, ρ, x_{1:T}) = Q_A(A) Q_{Cρ}(C, ρ) Q_x(x_{1:T}).

This factorisation falls out from the initial assumptions and the conditional independencies in the model.

Priors:
- Sensor precisions ρ are given conjugate Gamma priors.
- The transition and output matrices are given conjugate zero-mean Gaussian priors (ARD).

Optimal Q(θ) distributions: LDS

The complete-data likelihood is of conjugate-exponential form, so the maximisations are iterative and analytic.

Dynamics matrix Q_A(A):
  a_i ∼ N(ā_i, Σ^A),   ā_i = (Sᵀ Σ^A)_i,   Σ^A = (diag(α) + W)^{-1}

Noise covariance Q_ρ(ρ):
  ρ_i ∼ Gamma(ã, b̃_i),   ã = a + T/2,   b̃_i = b + g_i/2

Output matrix Q_C(C | ρ):
  c_i ∼ N(c̄_i, Σ^C_i),   c̄_i = ρ_i U_i Σ^C_i,   Σ^C_i = (diag(β) + W′)^{-1} / ρ_i

Hidden state Q_x(x_{1:T}):
  x_t ∼ N(ω_t^{(T)}, Ψ_t^{(T)}),   { ω_t^{(T)}, Ψ_t^{(T)} } = KalmanSmoother(Q_{ACρ}, y_{1:T})

Hyperparameters α, β:
  1/α = ⟨a · a⟩_{Q_A},   1/β = ⟨c · c⟩_{Q_{Cρ}}

Sufficient statistics:
  S = Σ_{t=2}^T ⟨x_{t-1} x_tᵀ⟩,   W = Σ_{t=1}^{T-1} ⟨x_t x_tᵀ⟩,   W′ = W + ⟨x_T x_Tᵀ⟩,
  U_i = Σ_{t=1}^T y_{ti} ⟨x_tᵀ⟩,   g_i = Σ_{t=1}^T y_{ti}² − U_i (diag(β) + W′)^{-1} U_iᵀ.

Method overview
1. Randomly initialise the approximate posteriors.
2. Update each posterior according to the variational E and M steps: Q_A(A), Q_{Cρ}(C | ρ), Q_ρ(ρ), Q_x(x_{1:T}).
3. Update the hyperparameters α and β.
4. Repeat steps 2-3 until convergence.

Variational E Step for LDS

[Figure: the state-space graphical model, hidden states x_1, ..., x_T with observations y_1, ..., y_T.]

- This is a conjugate-exponential, singly connected belief network,
- so by Corollary 1 we can use belief propagation, which for an LDS is the Kalman smoother:
  forward pass: infer the state x_t given past observations;
  backward pass: infer the state x_t given future observations.
- Inference using an ensemble of parameters is possible, at the same cost as using just a single setting of the parameters. Required averages: ⟨ρ_i c_i⟩, ⟨ρ_i c_i c_iᵀ⟩, ⟨A⟩, ⟨Aᵀ A⟩.
- The result is a recursive algorithm very similar to the Kalman smoothing propagation.

Synthetic experiments

Generated a variety of state-space models and looked at the recovered hidden structure. The observed data is a 10-dimensional output over 200 time-steps.

- True model: 3-dim static state-space. [Figure: Hinton diagrams of the inferred dynamics (A) and output (C) matrices.]
- True model: 3-dim dynamical state-space. [Figure: inferred model.]
- True model: 4-dim state-space, of which 3 dimensions are dynamical. [Figure: inferred model.]

Degradation with data support

True model: 6-dim state-space, 10-dim output.

The complexity of the recovered structure decreases with decreasing data support (from 400 down to 10 time-steps).

[Figure: Hinton diagrams of the inferred models for T = 400, 350, 250, 100, 30, 20 and 10 time-steps.]

Summary
- Bayesian learning avoids overfitting and can be used to learn model structure.
- Tractable learning using variational approximations.
- Conjugate-exponential families of models.
- Variational EM and integration of existing algorithms.
- Examples: learning structure in MFA and LDS models.

Issues & Extensions
- Quality of the variational approximations?
- How best to bring a model into the exponential family?
- Extensions to other models, e.g. switching state-space models [Ueda], are straightforward.
- Problems integrating existing algorithms into the VE step.
- Implementation on a general graphical model.

Optimal Q(θ) distributions: MFA

  Q(x_n | s_n) ∼ N(x̄_{n,s}, Σ^s):
    (Σ^s)^{-1} = I + ⟨Λ^{s}ᵀ Ψ^{-1} Λ^s⟩_{Q(Λ^s)}
    x̄_{n,s} = Σ^s Λ̄^{s}ᵀ Ψ^{-1} y_n

  Q(Λ^s_q) ∼ N(Λ̄^s_q, Σ^{q,s}):
    Λ̄^s_q = [ Ψ^{-1}_{qq} Σ_{n=1}^N Q(s_n) y_{nq} x̄_{n,s}ᵀ ] Σ^{q,s}
    (Σ^{q,s})^{-1} = Ψ^{-1}_{qq} Σ_{n=1}^N Q(s_n) ⟨x_n x_nᵀ⟩_{Q(x_n|s_n)} + diag⟨ν^s⟩_{Q(ν^s)}

  Q(ν^s_l) ∼ Gamma(a^s_l, b^s_l):
    a^s_l = a + p/2,   b^s_l = b + ½ Σ_{q=1}^p ⟨(Λ^s_{ql})²⟩_{Q(Λ^s)}

  Q(π) ∼ Dirichlet(ωu):
    ωu_s = α/S + Σ_{n=1}^N Q(s_n)

  ln Q(s_n) = [ψ(ωu_s) − ψ(ω)] + ½ ln |Σ^s| + ⟨ln P(y_n | x_n, s_n, Λ^s, Ψ)⟩_{Q(x_n|s_n) Q(Λ^s)} + c

Note that the optimal distributions Q(Λ^s) have block-diagonal covariance structure: even though each Λ^s is a p × q matrix, its covariance only has O(pq²) parameters.

Differentiating F with respect to the parameters a and b of the precision prior, we get the fixed-point equations ψ(a) = ⟨ln ν⟩ + ln b and b = a/⟨ν⟩. Similarly, the fixed point for the hyperparameter of the Dirichlet prior is ψ(α) − ψ(α/S) + Σ_s [ψ(ωu_s) − ψ(ω)]/S = 0.

Sampling from Variational Approximations

Sampling θ_m ∼ Q(θ) gives us estimates of:

- The Exact Predictive Density:

  P(y | Y) = ∫ dθ P(y | θ) P(θ | Y) = ∫ dθ Q(θ) P(y | θ) P(θ | Y)/Q(θ) ≈ Σ_{m=1}^M P(y | θ_m) ω_m

  with weights ω_m = Ω^{-1} P(θ_m, Y)/Q(θ_m), where Ω is chosen such that Σ_m ω_m = 1.

- The True Evidence:

  P(Y | m) = ∫ dθ Q(θ) P(θ, Y)/Q(θ) = ⟨Ω ω⟩.

- The KL Divergence:

  KL(Q(θ) ‖ P(θ | Y)) = ln⟨ω⟩ − ⟨ln ω⟩.
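These estimators are easy to check on a model where the posterior and evidence are available in closed form. The added toy sketch below assumes a Gaussian mean with known unit noise variance, a N(0, 1) prior, and a deliberately over-confident Gaussian Q(θ); it estimates the evidence as the average unnormalised weight P(θ_m, Y)/Q(θ_m), and the KL divergence as ln⟨ω⟩ − ⟨ln ω⟩.

import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

rng = np.random.default_rng(0)
y = rng.normal(1.0, 1.0, size=20)          # y_i ~ N(theta, 1), prior theta ~ N(0, 1)
n, ybar = len(y), y.mean()

# exact posterior and exact log evidence (conjugate Gaussian algebra)
post_var = 1.0 / (1.0 + n)
post_mean = post_var * n * ybar
log_ev_exact = (-0.5 * n * np.log(2 * np.pi) - 0.5 * np.sum(y ** 2)
                - 0.5 * np.log(n + 1.0) + (n * ybar) ** 2 / (2.0 * (n + 1.0)))

# a deliberately over-confident "variational" approximation Q(theta)
q_mean, q_var = post_mean, 0.5 * post_var

# weights  w_m = P(theta_m, Y) / Q(theta_m)  for samples theta_m ~ Q
M = 20000
thetas = rng.normal(q_mean, np.sqrt(q_var), size=M)
log_joint = norm.logpdf(y[None, :], thetas[:, None], 1.0).sum(axis=1) \
            + norm.logpdf(thetas, 0.0, 1.0)
log_w = log_joint - norm.logpdf(thetas, q_mean, np.sqrt(q_var))

log_ev_est = logsumexp(log_w) - np.log(M)          # ln <w>
kl_est = log_ev_est - log_w.mean()                 # ln<w> - <ln w>
kl_exact = 0.5 * (np.log(post_var / q_var) + q_var / post_var - 1.0)

print("log evidence:", round(log_ev_exact, 3), "sampled:", round(log_ev_est, 3))
print("KL(Q||posterior):", round(kl_exact, 4), "sampled:", round(kl_est, 4))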

Variational Bayes & Ensemble Learning
- multilayer perceptrons (Hinton & van Camp, 1993)
- mixture of experts (Waterhouse, MacKay & Robinson, 1996)
- hidden Markov models (MacKay, 1995)
- other work by Jaakkola, Barber, Bishop, Tipping, etc.

Examples of VB Learning Model Structure

Model learning has been treated with variational Bayesian techniques for:
- mixtures of factor analysers (Ghahramani & Beal, 1999)
- mixtures of Gaussians (Attias, 1999)
- independent components analysis (Attias, 1999; Miskin & MacKay, 2000; Valpola, 2000)
- principal components analysis (Bishop, 1999)
- linear dynamical systems (Ghahramani & Beal, 2000)
- mixture of experts (Ueda & Ghahramani, 2000)
- hidden Markov models (Ueda & Ghahramani, in prep.)
(compiled by Zoubin)

[Figure: evolution of F, the true evidence and the KL-divergence, together with train and test fractional error, plotted against iteration; curves shown for train error, test error, ln p(Y), train F(π, Λ) and train F(θ).]

Required expectations for linear-Gaussian state-space models

Defining G′ = diag[ (a + T/2) / (b + G/2) ], the required expectations are

  ⟨Cᵀ R^{-1} C⟩ = (diag(β) + W′)^{-1} [ p + U G′ Uᵀ (diag(β) + W′)^{-1} ]
  ⟨Cᵀ R^{-1}⟩  = (diag(β) + W′)^{-1} U G′
  ⟨A⟩          = Sᵀ (diag(α) + W)^{-1}
  ⟨Aᵀ A⟩       = k (diag(α) + W)^{-1} + ⟨A⟩ᵀ ⟨A⟩

where S = Σ_{t=2}^T ⟨x_{t-1} x_tᵀ⟩, W = Σ_{t=1}^{T-1} ⟨x_t x_tᵀ⟩, U_i = Σ_{t=1}^T y_{ti} ⟨x_tᵀ⟩ and W′ = W + ⟨x_T x_Tᵀ⟩.

Kalman smoothing propagation

Forward recursion (filtered moments μ_t, Σ_t):

  Σ*_t = ( Σ_t^{-1} + ⟨Aᵀ A⟩ )^{-1}
  Σ_t  = ( I + ⟨Cᵀ R^{-1} C⟩ − ⟨A⟩ Σ*_{t-1} ⟨A⟩ᵀ )^{-1}
  μ_t  = Σ_t ( ⟨Cᵀ R^{-1}⟩ y_t + ⟨A⟩ Σ*_{t-1} Σ_{t-1}^{-1} μ_{t-1} )

Backward recursion (smoothed moments ω_t, Ψ_t, starting from ω_T = μ_T, Ψ_T = Σ_T):

  Ψ_t = [ Σ*_t^{-1} − ⟨A⟩ᵀ ( Ψ_{t+1}^{-1} + ⟨A⟩ Σ*_t ⟨A⟩ᵀ )^{-1} ⟨A⟩ ]^{-1}
  ω_t = Ψ_t [ Σ_t^{-1} μ_t + ⟨A⟩ᵀ ( Ψ_{t+1}^{-1} + ⟨A⟩ Σ*_t ⟨A⟩ᵀ )^{-1} ( Ψ_{t+1}^{-1} ω_{t+1} − ⟨A⟩ Σ*_t Σ_t^{-1} μ_t ) ]
  cov(x_t, x_{t+1}) = Ψ_t ⟨A⟩ᵀ ( Ψ_{t+1}^{-1} + ⟨A⟩ Σ*_t ⟨A⟩ᵀ )^{-1}
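A compact sketch of the forward recursion above, taking the required expectations as inputs (an added illustration; the prior on x_1 is assumed to be N(mu0, Sigma0), which is not specified in the transcription):

import numpy as np

def vb_forward_filter(Y, EA, EAtA, ECRC, ECR, mu0, Sigma0):
    # Forward pass of the variational Kalman smoother (a sketch): uses the expected
    # quantities <A>, <A^T A>, <C^T R^-1 C>, <C^T R^-1> in place of point parameters.
    T, k = Y.shape[0], EA.shape[0]
    mu = np.zeros((T, k))
    Sigma = np.zeros((T, k, k))
    # t = 1: combine the assumed state prior with the (expected) emission term
    Sigma[0] = np.linalg.inv(np.linalg.inv(Sigma0) + ECRC)
    mu[0] = Sigma[0] @ (ECR @ Y[0] + np.linalg.inv(Sigma0) @ mu0)
    for t in range(1, T):
        Sigma_star = np.linalg.inv(np.linalg.inv(Sigma[t - 1]) + EAtA)
        Sigma[t] = np.linalg.inv(np.eye(k) + ECRC - EA @ Sigma_star @ EA.T)
        mu[t] = Sigma[t] @ (ECR @ Y[t]
                            + EA @ Sigma_star @ np.linalg.inv(Sigma[t - 1]) @ mu[t - 1])
    return mu, Sigma

# toy usage with point parameters and R = I, so the expectations reduce to A, A^T A,
# C^T C and C^T (arbitrary sizes, for illustration only)
k, D, T = 2, 3, 50
rng = np.random.default_rng(0)
A = 0.8 * np.eye(k); C = rng.normal(size=(D, k))
Y = rng.normal(size=(T, D))
mu, Sig = vb_forward_filter(Y, A, A.T @ A, C.T @ C, C.T, np.zeros(k), np.eye(k))
print(mu.shape, Sig.shape)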