Variational Message Passing
John Winn, Christopher M. Bishop
Presented by Andy Miller
Overview
- Background: Variational Inference
- Conjugate-Exponential Models
- Variational Message Passing: messages, univariate Gaussian example
- Allowable Models and Constraints
- VIBES and Extensions
Variational Inference

Our model (a directed graphical model): X = (V, H)
- V: visible (observed) variables
- H: latent variables, including parameters

Our dilemma: exact inference algorithms are computationally intractable for all but the simplest models.

Our goal: find a tractable approximation Q(H) \approx P(H | V).
Variational Inference

Note the natural decomposition of the log likelihood:

\ln p(V) = \mathcal{L}(Q) + \mathrm{KL}(Q \| P)

where

\mathcal{L}(Q) = \sum_H Q(H) \ln \frac{P(H, V)}{Q(H)}

\mathrm{KL}(Q \| P) = -\sum_H Q(H) \ln \frac{P(H | V)}{Q(H)}
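This decomposition follows in one line from the product rule P(H, V) = P(H | V) \, p(V) (a step filled in here for completeness):

\mathcal{L}(Q) = \sum_H Q(H) \ln \frac{P(H | V) \, p(V)}{Q(H)} = \ln p(V) + \sum_H Q(H) \ln \frac{P(H | V)}{Q(H)} = \ln p(V) - \mathrm{KL}(Q \| P)

using \sum_H Q(H) = 1. Since the KL divergence is nonnegative, \mathcal{L}(Q) is a lower bound on \ln p(V).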
Variational Inference

For arbitrary Q:

\ln p(V) = \mathcal{L}(Q) + \mathrm{KL}(Q \| P)

The left-hand side is fixed, so maximizing \mathcal{L}(Q) is equivalent to minimizing \mathrm{KL}(Q \| P).

What if we minimize \mathrm{KL}(Q \| P) w.r.t. an unrestricted Q? We get Q(H) = P(H | V), but computing this exact posterior is precisely what we are trying to avoid.
Variational Inference

Family of distributions explored by Winn: the factorized form

Q(H) = \prod_i Q_i(H_i)

where the \{H_i\} are disjoint groups of latent variables.

This vastly reduces the space. E.g., for a fully disjoint set of discrete variables with |H| = N and each H_i \in \{1, \ldots, K\}, Q requires only KN parameters where the full P space requires K^N (for N = 10 and K = 5, that is 50 versus 5^{10} \approx 9.8 million).
Variational Inference

Plug the factorized Q into the lower-bound equation:

\mathcal{L}(Q) = \sum_H \prod_i Q_i(H_i) \Big[ \ln P(H, V) - \sum_i \ln Q_i(H_i) \Big]

Separate out all terms involving one factor Q_j and introduce a new distribution Q_j^*:

\mathcal{L}(Q) = -\mathrm{KL}(Q_j \| Q_j^*) + \text{terms not involving } Q_j

Maximizing \mathcal{L} w.r.t. Q_j thus means minimizing this KL divergence.
Variational Inference

Maximizing the lower bound w.r.t. a single factor Q_j:

\ln Q_j^*(H_j) = \langle \ln P(H, V) \rangle_{\sim Q(H_j)} + \text{const.}

Q_j^*(H_j) = \frac{1}{Z} \exp \langle \ln P(H, V) \rangle_{\sim Q(H_j)}

where \langle \cdot \rangle_{\sim Q(H_j)} denotes an expectation w.r.t. all factors except Q_j. The solutions are coupled: each Q_j depends on expectations w.r.t. the other factors Q_{i \neq j}. Variational optimization proceeds by initializing each Q_i and then cycling through the factors, as in the sketch below.
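To make the cycling concrete, here is a minimal runnable sketch (an illustration written for this transcript, not taken from the slides): mean-field coordinate ascent for the classic toy problem of approximating a correlated bivariate Gaussian by a factorized Q(z_1)Q(z_2), where each factor's optimal mean depends on the current estimate of the other factor (cf. Bishop, PRML Section 10.1.2). All names and values are assumptions for illustration.

```python
import numpy as np

# Toy "posterior": a correlated bivariate Gaussian p(z) = N(m, inv(Lam)),
# standing in for an intractable P(H | V).
m = np.array([0.0, 0.0])
Lam = np.array([[2.0, 1.2],
                [1.2, 2.0]])  # precision matrix (off-diagonals couple z1, z2)

# Mean-field family Q(z1)Q(z2): the optimal factor Q(z_i) is Gaussian with
# precision Lam[i, i] and a mean that depends on the OTHER factor's mean,
# so the updates are coupled and must be cycled until they stop changing.
mu = np.array([5.0, -5.0])  # deliberately poor initialization
for sweep in range(20):
    mu[0] = m[0] - Lam[0, 1] * (mu[1] - m[1]) / Lam[0, 0]
    mu[1] = m[1] - Lam[1, 0] * (mu[0] - m[0]) / Lam[1, 1]

print(mu)  # converges to the true mean m
```

Each update is exactly the generic formula above: an expectation of \ln P under all the other factors, renormalized; each sweep can only increase \mathcal{L}(Q).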
Variational Inference (recap)
1. Choose a family Q(H) of variational distributions.
2. Use the Kullback-Leibler divergence KL(Q \| P) as a measure of distance between P(H | V) and Q(H).
3. Find the Q that minimizes the divergence (or, equivalently, maximizes the lower bound).
KL Divergence

[Figure: behavior of minimizing KL(p \| q) versus minimizing KL(q \| p). Figures from Pattern Recognition and Machine Learning, Bishop, 2006.]
Variational Inference in Directed Model

Assuming a directed graphical model, the full distribution factorizes as

P(X) = \prod_i P(X_i | \mathrm{pa}_i)

Winn assumes a fully factorized Q:

Q(H) = \prod_i Q_i(H_i)
Variational Inference in Directed Model

Plugging the factorized joint into the optimal form for factor j:

\ln Q_j^* = \Big\langle \sum_i \ln P(X_i | \mathrm{pa}_i) \Big\rangle_{\sim Q(H_j)} + \text{const.}

Terms not depending on H_j will be constant, yielding:

\ln Q_j^*(H_j) = \langle \ln P(H_j | \mathrm{pa}_j) \rangle_{\sim Q(H_j)} + \sum_{k \in \mathrm{ch}_j} \langle \ln P(X_k | \mathrm{pa}_k) \rangle_{\sim Q(H_j)} + \text{const.}

The distribution relies only on the node's parents, children, and co-parents.
Variational Inference in Directed Model For a factorized Q, each update equation relies only on variables in the Markov blanket Can decompose the overall optimization into a set of local computations
Conjugate-Exponential Models

Simplify the update equations by requiring:
- conditional distributions from the exponential family
- conjugacy w.r.t. the distributions over parent variables

A parent distribution p(x | y) is said to be conjugate to a child distribution p(w | x) if p(x | y) has the same functional form, with respect to x, as p(w | x). The posterior then satisfies

p(x | w, y) \propto p(w | x) \, p(x | y)

and stays in the same family, with the same functional form.
Conjugate-Exponential Models

Conditional distributions are expressed in exponential-family form:

\ln P(X | \theta) = \theta^T u(X) + g(\theta) + f(X)

with natural parameter vector \theta and sufficient statistics vector u(X).

E.g. the univariate Gaussian:

\ln P(X | \mu, \gamma) = \begin{bmatrix} \mu\gamma \\ -\gamma/2 \end{bmatrix}^T \begin{bmatrix} X \\ X^2 \end{bmatrix} + \frac{1}{2}\left( \ln\gamma - \gamma\mu^2 - \ln 2\pi \right)
Conjugate-Exponential Models

Parents and children are chosen to be conjugate, i.e. to have the same functional form:

\ln P(X | \theta) = \theta^T u(X) + g(\theta) + f(X)

\ln P(Z | X, Y) = \phi(Y, Z)^T u(X) + \lambda(Y, Z)

Both are linear in the same sufficient statistics vector u(X). E.g.:
- Gaussian for the mean of a Gaussian
- Gamma for the precision of a Gaussian
- Dirichlet for the parameters of a discrete distribution
Conjugate-Exponential Models

P(Y | X, \mathrm{pa}_Y, \mathrm{cp}_Y) \propto P(X | Y, \mathrm{cp}_Y) \, P(Y | \mathrm{pa}_Y)

Conjugacy ensures the posterior stays in the same family as the prior, i.e. has the same form w.r.t. Y.
Conjugate-Exponential Models lnp(y pa Y ) = φ Y ( pa Y ) Τ u Y (Y) + f Y (Y) +g Y (pa Y ) lnp(x Y,cp Y ) = φ X (Y,cp Y ) Τ u X (X) + f X (X) +g X (Y,cp Y ) = φ XY (X,cp Y ) Τ u Y (Y) + λ(x,cp Y )
Conjugate-Exponential Models lnq Y * (Y) = φ Y ( pa Y ) Τ u Y (Y) +f Y (Y) + g Y (pa Y ) ~Q(Y ) + φ XY (X k,cp k ) Τ u Y (Y) +λ(x k,cp k ) ~Q(Y ) + const. k ch j = φ Y ( pa Y ) ~Q(Y ) + φ XY (X k,cp k ) ~Q(Y ) k ch Y Τ u Y (Y)+f Y (Y) + const.
Conjugate-Exponential Models

So the optimal factor has the updated natural parameter vector

\phi_Y^* = \langle \phi_Y(\mathrm{pa}_Y) \rangle_{\sim Q(Y)} + \sum_{k \in \mathrm{ch}_Y} \langle \phi_{XY}(X_k, \mathrm{cp}_k) \rangle_{\sim Q(Y)}

where
- \langle \phi_Y(\mathrm{pa}_Y) \rangle_{\sim Q(Y)} is a function of \{\langle u_i \rangle\}_{i \in \mathrm{pa}_Y}
- \langle \phi_{XY}(X_k, \mathrm{cp}_k) \rangle_{\sim Q(Y)} is a function of \langle u_k \rangle and \{\langle u_j \rangle\}_{j \in \mathrm{cp}_k}
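As a concrete instance (worked out here for orientation; it anticipates the Gaussian example on the following slides): for Y = \mu with prior N(\mu | m, \beta^{-1}) and N Gaussian observations x_n with precision \gamma, we have u_\mu(\mu) = [\mu, \mu^2]^T and

\phi_\mu^* = \begin{bmatrix} \beta m \\ -\beta/2 \end{bmatrix} + \sum_{n=1}^N \begin{bmatrix} \langle\gamma\rangle x_n \\ -\langle\gamma\rangle/2 \end{bmatrix} = \begin{bmatrix} \beta m + \langle\gamma\rangle \sum_n x_n \\ -(\beta + N\langle\gamma\rangle)/2 \end{bmatrix}

which is again a Gaussian, with precision \beta + N\langle\gamma\rangle and mean (\beta m + \langle\gamma\rangle \sum_n x_n) / (\beta + N\langle\gamma\rangle).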
Variational Message Passing

Conditional distributions:

\ln P(X | \theta) = \theta^T u(X) + g(\theta) + f(X)

\ln P(Z | X, Y) = \phi(Y, Z)^T u(X) + \lambda(Y, Z)

Messages:
- Parent to child (X \to Z): the expected sufficient statistics of the parent, m_{X \to Z} = \langle u_X \rangle
- Child to parent (Z \to X): m_{Z \to X} = \langle \phi(Y, Z) \rangle, computed using the messages received from the co-parent Y
Variational Message Passing

The optimal Q(X) has the same form as P(X | \theta), but with an updated natural parameter vector \theta^*, computed from the messages arriving from X's parents and children.
VMP Example: defining the model

Learning the parameters of a Gaussian from N data points:
- mean \mu and precision \gamma (inverse variance)
- observed data x_1, \ldots, x_N

P(\mathbf{x} | \mu, \gamma^{-1}) = \prod_{n=1}^N N(x_n | \mu, \gamma^{-1})
VMP Example: defining the model

As a function of the data, the sufficient statistics are u_x(x_n) = [x_n, x_n^2]^T:

\ln P(x_n | \mu, \gamma^{-1}) = \begin{bmatrix} \gamma\mu \\ -\gamma/2 \end{bmatrix}^T \begin{bmatrix} x_n \\ x_n^2 \end{bmatrix} + \frac{1}{2}\left( \ln\gamma - \gamma\mu^2 - \ln 2\pi \right)
VMP Example: defining the model

As a function of the mean, u_\mu(\mu) = [\mu, \mu^2]^T:

\ln P(x_n | \mu, \gamma^{-1}) = \begin{bmatrix} \gamma x_n \\ -\gamma/2 \end{bmatrix}^T \begin{bmatrix} \mu \\ \mu^2 \end{bmatrix} + \frac{1}{2}\left( \ln\gamma - \gamma x_n^2 - \ln 2\pi \right)
VMP Example: defining the model

As a function of the precision, u_\gamma(\gamma) = [\gamma, \ln\gamma]^T:

\ln P(x_n | \mu, \gamma^{-1}) = \begin{bmatrix} -\frac{1}{2}(x_n - \mu)^2 \\ \frac{1}{2} \end{bmatrix}^T \begin{bmatrix} \gamma \\ \ln\gamma \end{bmatrix} - \frac{1}{2}\ln 2\pi
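These three readings of the same density are easy to check numerically. Here is a quick sanity check written for this transcript (the test values are arbitrary; scipy's Gaussian log-density is the reference):

```python
import numpy as np
from scipy.stats import norm

x, mu, gamma = 1.3, 0.4, 2.5  # arbitrary test values

# ln N(x | mu, 1/gamma) written three ways as phi^T u + const, depending on
# which variable plays the role of "the" random variable.
as_x = np.dot([mu * gamma, -gamma / 2], [x, x**2]) \
    + 0.5 * (np.log(gamma) - gamma * mu**2 - np.log(2 * np.pi))
as_mu = np.dot([gamma * x, -gamma / 2], [mu, mu**2]) \
    + 0.5 * (np.log(gamma) - gamma * x**2 - np.log(2 * np.pi))
as_gamma = np.dot([-0.5 * (x - mu)**2, 0.5], [gamma, np.log(gamma)]) \
    - 0.5 * np.log(2 * np.pi)

reference = norm.logpdf(x, loc=mu, scale=gamma**-0.5)
assert np.allclose([as_x, as_mu, as_gamma], reference)
print("all three decompositions agree:", reference)
```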
VMP Example: defining the model

Priors on the parameters:
- Gaussian distribution on the mean, with hyperparameters (m, \beta):

\ln P(\mu | m, \beta^{-1}) = \begin{bmatrix} \beta m \\ -\beta/2 \end{bmatrix}^T \begin{bmatrix} \mu \\ \mu^2 \end{bmatrix} + \frac{1}{2}\left( \ln\beta - \beta m^2 - \ln 2\pi \right)

- Gamma distribution on the precision, with hyperparameters (a, b):

\ln P(\gamma | a, b) = \begin{bmatrix} -b \\ a - 1 \end{bmatrix}^T \begin{bmatrix} \gamma \\ \ln\gamma \end{bmatrix} + a \ln b - \ln\Gamma(a)
VMP Example: passing messages

Variational distribution: Q(\mu, \gamma) = Q(\mu) Q(\gamma)

Find initial values of the expected sufficient statistics \langle u_\mu(\mu) \rangle and \langle u_\gamma(\gamma) \rangle.
VMP Example: passing messages

Message from \gamma to all x_n:

m_{\gamma \to x_n} = \begin{bmatrix} \langle\gamma\rangle \\ \langle\ln\gamma\rangle \end{bmatrix}
VMP Example: passing messages

Messages from each x_n to \mu:

m_{x_n \to \mu} = \begin{bmatrix} \langle\gamma\rangle x_n \\ -\langle\gamma\rangle/2 \end{bmatrix}

since \ln P(x_n | \mu, \gamma^{-1}) = \begin{bmatrix} \gamma x_n \\ -\gamma/2 \end{bmatrix}^T \begin{bmatrix} \mu \\ \mu^2 \end{bmatrix} + \frac{1}{2}\left( \ln\gamma - \gamma x_n^2 - \ln 2\pi \right)
VMP Example: passing messages

Update the Q(\mu) parameter vector:

\phi_\mu^* = \phi_\mu + \sum_{n=1}^N m_{x_n \to \mu} = \begin{bmatrix} \beta m \\ -\beta/2 \end{bmatrix} + \sum_{n=1}^N m_{x_n \to \mu}
VMP Example: passing messages

Message from the updated \mu to all x_n:

m_{\mu \to x_n} = \begin{bmatrix} \langle\mu\rangle \\ \langle\mu^2\rangle \end{bmatrix}
VMP Example: passing messages

Messages from each x_n to \gamma:

m_{x_n \to \gamma} = \begin{bmatrix} -\frac{1}{2}\langle (x_n - \mu)^2 \rangle \\ \frac{1}{2} \end{bmatrix}

where \langle (x_n - \mu)^2 \rangle = x_n^2 - 2 x_n \langle\mu\rangle + \langle\mu^2\rangle, since

\ln P(x_n | \mu, \gamma^{-1}) = \begin{bmatrix} -\frac{1}{2}(x_n - \mu)^2 \\ \frac{1}{2} \end{bmatrix}^T \begin{bmatrix} \gamma \\ \ln\gamma \end{bmatrix} - \frac{1}{2}\ln 2\pi
VMP Example: passing messages

Update the Q(\gamma) parameter vector:

\phi_\gamma^* = \phi_\gamma + \sum_{n=1}^N m_{x_n \to \gamma} = \begin{bmatrix} -b \\ a - 1 \end{bmatrix} + \sum_{n=1}^N m_{x_n \to \gamma}
VMP Example: converged distribution

Iterating these updates to convergence yields

Q(\mu, \gamma) = Q_\mu(\mu) \, Q_\gamma(\gamma) \approx P(\mu, \gamma | V)
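The whole example fits in a short script. Below is a minimal runnable sketch of this message-passing loop (written for this transcript; the hyperparameter values and synthetic data are illustrative assumptions, not from the slides):

```python
import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=100)  # synthetic data: mu=2, gamma=4
N = x.size

m, beta = 0.0, 1e-3   # Gaussian prior N(mu | m, 1/beta), deliberately broad
a, b = 1e-3, 1e-3     # Gamma(gamma | shape a, rate b) prior, deliberately broad

# Prior natural-parameter vectors, exactly as on the slides.
phi_mu = np.array([beta * m, -beta / 2])  # pairs with u_mu = [mu, mu^2]
phi_gam = np.array([-b, a - 1])           # pairs with u_gam = [gamma, ln gamma]

E_gamma = a / b  # initialize <gamma> from the prior
for sweep in range(50):
    # Messages x_n -> mu are [<gamma> x_n, -<gamma>/2]; sum them, update Q(mu).
    phi_mu_star = phi_mu + np.array([E_gamma * x.sum(), -N * E_gamma / 2])
    beta_q = -2 * phi_mu_star[1]             # posterior precision of mu
    m_q = phi_mu_star[0] / beta_q            # posterior mean of mu
    E_mu, E_mu2 = m_q, m_q**2 + 1 / beta_q   # <mu>, <mu^2>: message mu -> x_n

    # Messages x_n -> gamma are [-<(x_n - mu)^2>/2, 1/2]; update Q(gamma).
    E_sq = x**2 - 2 * x * E_mu + E_mu2       # <(x_n - mu)^2> for each n
    phi_gam_star = phi_gam + np.array([-E_sq.sum() / 2, N / 2])
    a_q = phi_gam_star[1] + 1                # posterior shape
    b_q = -phi_gam_star[0]                   # posterior rate
    E_gamma = a_q / b_q                      # <gamma>
    E_ln_gamma = digamma(a_q) - np.log(b_q)  # <ln gamma>: message gamma -> x_n

print(f"Q(mu) = N({m_q:.3f}, precision {beta_q:.1f})")
print(f"Q(gamma) = Gamma({a_q:.1f}, {b_q:.1f}), <gamma> = {E_gamma:.2f}")
```

With these broad priors the fixed point matches the usual conjugate solution: the posterior mean of \mu is close to the sample mean, and \langle\gamma\rangle is close to one over the sample variance.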
Features of VMP
- Guaranteed to converge to a local minimum of KL(Q \| P)
- Flexible message-passing schedule: factors can be updated in any order (though the order may affect convergence)
- The graph does not need to be a tree (it does need to be acyclic)
Allowable Models and Constraints

Parent-child edges must satisfy conjugacy:
- Gaussian variable: Gaussian parent for its mean, Gamma parent for its precision
- Gamma variable: Gamma parent for its rate (inverse-scale) parameter b
- Discrete variable: Dirichlet prior
Allowable Models and Constraints

Table 1: Distributions for each parameter of a number of exponential-family distributions if the model is to satisfy conjugacy constraints (entries reconstructed here from the constraints listed on the previous slide):

Distribution | Parameter | Conjugate parent distribution
Gaussian | mean \mu | Gaussian
Gaussian | precision \gamma | Gamma
Gamma | shape a | None
Gamma | rate b | Gamma
Discrete | probabilities | Dirichlet

Conjugacy also holds if the distributions are replaced by their multivariate counterparts, e.g. the distribution conjugate to the precision matrix of a multivariate Gaussian is a Wishart distribution. Where None is specified, no standard distribution satisfies conjugacy.
Allowable Models and Constraints

Also supported:
- Truncated distributions
- Deterministic variables
- Mixture distributions
- Multivariate distributions
Allowable Models: Mixture Models

A mixture is not in the exponential family:

P(X | \{\pi_k\}, \{\theta_k\}) = \sum_{k=1}^K \pi_k P_k(X | \theta_k)

Introduce a latent variable \lambda that indicates the component:

P(X | \lambda, \{\theta_k\}) = \prod_{k=1}^K P_k(X | \theta_k)^{\delta_{\lambda k}}

\ln P(X | \lambda, \{\theta_k\}) = \sum_k \delta(\lambda, k) \left[ \phi_k(\theta_k)^T u_k(X) + f_k(X) + g_k(\theta_k) \right]
Allowable Models: Mixture Models

Require that all component distributions have the same natural statistic vector:

u_X(X) \triangleq u_1(X) = \ldots = u_K(X)

Then the log conditional can be rewritten as

\ln P(X | \lambda, \{\theta_k\}) = \phi_X(\lambda, \{\theta_k\})^T u_X(X) + f_X(X) + \tilde{g}\big( \phi_X(\lambda, \{\theta_k\}) \big)

where \phi_X = \sum_k \delta(\lambda, k) \, \phi_k(\theta_k)
Allowable Models: Mixture Models

Now the messages are defined as:
- to the parameters of component k: the usual child-to-parent message, weighted by the responsibility, m_{X \to \theta_k} = Q(\lambda = k) \, \langle \phi_{X\theta_k} \rangle
- to any child Z \in \mathrm{ch}_X: m_{X \to Z} = \langle u_X(X) \rangle, as for a non-mixture node
- to the indicator: m_{X \to \lambda} has components \langle \ln P_k(X | \theta_k) \rangle_{\theta_k} for k = 1, \ldots, K
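Putting these together (a step spelled out here; it is the standard mean-field result for mixtures rather than something shown on the slide): with a discrete prior P(\lambda | \pi), the indicator's factor update is

\ln Q^*(\lambda = k) = \langle \ln \pi_k \rangle + \langle \ln P_k(X | \theta_k) \rangle + \text{const.}

so Q(\lambda = k) \propto \exp\big( \langle \ln \pi_k \rangle + \langle \ln P_k(X | \theta_k) \rangle \big), i.e. exactly the responsibilities that weight the messages to each \theta_k.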
Allowable Models

General architecture:
- arbitrary directed acyclic subgraphs of multinomial discrete variables (with Dirichlet priors)
- arbitrary subgraphs of univariate and multivariate linear-Gaussian nodes (with gamma and Wishart priors)
- arbitrary mixture nodes providing connections from the discrete to the continuous subgraphs
- deterministic nodes can be included
- any continuous distribution can be truncated to restrict its range of allowable values

Includes: HMMs, Kalman filters, factor analysis, PCA, ICA
VIBES

VIBES (a Variational Inference Engine for Bayesian Networks) was inspired by WinBUGS: specify models graphically, and run variational inference on them.
Extensions
- Additional variational parameters can be introduced to handle non-conjugate distributions; e.g. the logistic sigmoid function can be bounded by a Gaussian-like function
- A next step would be to achieve a posterior estimate with some dependency structure (i.e. structured variational inference)