School of Computer Science
Probabilistic Graphical Models
Variational Inference II: Mean Field Method and Variational Principle
Junming Yin
Lecture 15, March 7, 2012
Reading:
[Figure: example graphical models over nodes X_1, X_2, X_3, X_4]

Recap
The loopy belief propagation (sum-product) algorithm is a method for finding a stationary point of the Bethe free energy, which is a direct approximation of the Gibbs free energy.
  We will revisit BP and the Bethe approximation from another point of view later.
Today we look at another approximate inference method, based on restricting the family of approximating distributions.

Variational Methods
"Variational": a fancy name for optimization-based formulations, i.e.,
  represent the quantity of interest as the solution to an optimization problem;
  approximate the desired solution by relaxing/approximating the intractable optimization problem.
Examples:
  Courant-Fischer for eigenvalues (symmetric $A$):
  $$\lambda_{\max}(A) = \max_{\|x\|_2 = 1} x^T A x$$
  Linear system of equations: $Ax = b$, $A \succ 0$, $x^* = A^{-1} b$
    variational formulation:
    $$x^* = \arg\min_x \Big\{ \tfrac{1}{2} x^T A x - b^T x \Big\}$$
    for a large system, apply the conjugate gradient method

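The linear-system example can be made concrete with a small sketch. The code below is an illustrative, pure-Python conjugate gradient routine, not taken from the lecture; the helper names (`dot`, `matvec`, `objective`, `conjugate_gradient`) and the 2x2 test system are assumptions chosen for the example. It minimizes $\tfrac{1}{2} x^T A x - b^T x$ for a symmetric positive definite $A$, and the minimizer coincides with the solution of $Ax = b$.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def matvec(A, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [dot(row, v) for row in A]


def objective(A, b, x):
    """Quadratic objective 0.5 * x^T A x - b^T x."""
    return 0.5 * dot(x, matvec(A, x)) - dot(b, x)


def conjugate_gradient(A, b, tol=1e-12, max_iter=1000):
    """Minimize the quadratic objective; assumes A is symmetric positive definite."""
    x = [0.0] * len(b)
    r = list(b)              # residual b - A x at the starting point x = 0
    p = list(r)              # first search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old < tol * tol:   # residual norm below tolerance
            break
        Ap = matvec(A, p)
        alpha = rs_old / dot(p, Ap)                       # exact line search
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Hypothetical example system (chosen for illustration only).
A_example = [[4.0, 1.0], [1.0, 3.0]]
b_example = [1.0, 2.0]
x_star = conjugate_gradient(A_example, b_example)
```

For this example system the exact solution is $x^* = (1/11, 7/11)$, so the returned `x_star` should agree with it up to floating-point error.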
Inference Problems in Graphical Models
Undirected graphical model (MRF):
$$p(x) = \frac{1}{Z} \prod_{C \in \mathcal{C}} \psi_C(x_C)$$
The quantities of interest:
  marginal distributions: $p(x_i) = \sum_{x_j,\, j \neq i} p(x)$
  normalization constant (partition function): $Z$
Question: how to represent these quantities in a variational form?

Variational Formulation
$$KL(Q \,\|\, P) = \underbrace{-H_Q(X) - \sum_{C \in \mathcal{C}} E_Q \log \psi_C(x_C)}_{F(P,Q)} + \log Z$$
$F(P,Q)$ is the Gibbs free energy, and $F(P,P) = -\log Z$.
In general, $F(P,Q)$ is a complicated function of the true marginals and is hard to compute.
Idea: construct an approximation $\hat{F}(P,Q)$ that has a nice functional form in terms of beliefs (approximate marginals) and is easy to optimize.
  Approach 1: directly approximate $F(P,Q)$ with $\hat{F}(P,Q)$, e.g., the Bethe approximation:
  $$F(P,Q) \approx \hat{F}(P,Q) = G_{Bethe}(\{q_i(x_i)\}, \{q_{ij}(x_i, x_j)\})$$
  Approach 2: restrict $Q$ to a tractable class of distributions.

Mean Field Methods
$$KL(Q \,\|\, P) = \underbrace{-H_Q(X) - \sum_{C \in \mathcal{C}} E_Q \log \psi_C(x_C)}_{F(P,Q)} + \log Z$$
$F(P,Q)$ is the Gibbs free energy.
Restrict $Q$ to a class for which $H_Q$ is feasible to compute:
  the objective to minimize is exact;
  the feasible set is tightened;
  this yields a lower bound on the log partition function: since $KL(Q \,\|\, P) \geq 0$, we have $-F(P,Q) \leq \log Z$.
$Q$ is a simple parameterized approximating distribution.
  The free parameters to tune are called variational parameters.

Naïve Mean Field
Completely factorized variational distribution:
$$q(x) = \prod_{i \in V} q_i(x_i)$$
$$H_Q = -\sum_{i \in V} \sum_{x_i} q_i(x_i) \log q_i(x_i)$$

Naïve Mean Field Free Energy
Consider a pairwise Markov random field:
$$p(x) \propto \prod_{i \in V} \psi_i(x_i) \prod_{(i,j) \in E} \psi_{ij}(x_i, x_j)$$
Naïve mean field free energy:
$$F(P,Q) = G_{MF}(q) = -\sum_{(i,j) \in E} \sum_{x_i, x_j} q_i(x_i) q_j(x_j) \log \psi_{ij}(x_i, x_j) - \sum_{i \in V} \sum_{x_i} q_i(x_i) \log \psi_i(x_i) + \sum_{i \in V} \sum_{x_i} q_i(x_i) \log q_i(x_i)$$
Use coordinate descent to optimize with respect to each $q_i$ in turn.

Naïve Mean Field for Ising Model
Ising model in the {0,1} representation:
$$p(x) \propto \exp\Big\{ \sum_{i \in V} x_i \theta_i + \sum_{(i,j) \in E} x_i x_j \theta_{ij} \Big\}$$
The true marginals are the mean parameters: $\mu_i = p(x_i = 1) = E_p(x_i)$.
The naïve mean field update equations:
$$q_i \leftarrow \sigma\Big( \theta_i + \sum_{j \in N(i)} \theta_{ij} q_j \Big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}$$
  $q_i := q_i(x_i = 1) = E_q[x_i]$ is the variational mean parameter at node $i$.
  The variational mean parameters are coupled among neighbors.

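Derivation
$$G_{MF}(q) = -\sum_{(i,j) \in E} q_i q_j \theta_{ij} - \sum_{i \in V} q_i \theta_i + \sum_{i \in V} \big( q_i \log q_i + (1 - q_i) \log(1 - q_i) \big)$$
$$\frac{\partial G_{MF}(q)}{\partial q_i} = -\sum_{j \in N(i)} q_j \theta_{ij} - \theta_i + \log q_i - \log(1 - q_i)$$
Setting this to zero gives
$$\theta_i + \sum_{j \in N(i)} q_j \theta_{ij} = \log \frac{q_i}{1 - q_i}$$
That is the mean field equation
$$q_i = \sigma\Big( \theta_i + \sum_{j \in N(i)} q_j \theta_{ij} \Big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}$$
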
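The mean field equation can be iterated as coordinate descent on $G_{MF}$. The sketch below is illustrative and not from the lecture: the function names, the dictionary format for edge parameters, the fixed number of sweeps, and the three-node example are all assumptions. It also computes $\log Z$ by brute-force enumeration (feasible only for tiny graphs) so that the lower bound $-G_{MF}(q) \leq \log Z$ can be checked numerically.

```python
import itertools
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mean_field_ising(theta, theta_edge, n_sweeps=200):
    """Naive mean field for a {0,1} Ising model.

    theta: list of node parameters theta_i.
    theta_edge: dict mapping an edge (i, j) to theta_ij (assumed format).
    Returns the variational mean parameters q_i = q_i(x_i = 1).
    """
    n = len(theta)
    neighbors = {i: [] for i in range(n)}
    for (i, j), w in theta_edge.items():
        neighbors[i].append((j, w))
        neighbors[j].append((i, w))
    q = [0.5] * n
    for _ in range(n_sweeps):
        for i in range(n):  # coordinate-wise update of q_i
            q[i] = sigmoid(theta[i] + sum(w * q[j] for j, w in neighbors[i]))
    return q


def mean_field_free_energy(q, theta, theta_edge):
    """G_MF(q) = -sum q_i q_j theta_ij - sum q_i theta_i - H(q)."""
    energy = -sum(w * q[i] * q[j] for (i, j), w in theta_edge.items())
    energy -= sum(t * qi for t, qi in zip(theta, q))
    neg_entropy = sum(qi * math.log(qi) + (1 - qi) * math.log(1 - qi) for qi in q)
    return energy + neg_entropy


def log_partition_brute_force(theta, theta_edge):
    """Exact log Z by enumerating all 2^n configurations (tiny graphs only)."""
    n = len(theta)
    total = 0.0
    for x in itertools.product([0, 1], repeat=n):
        score = sum(theta[i] * x[i] for i in range(n))
        score += sum(w * x[i] * x[j] for (i, j), w in theta_edge.items())
        total += math.exp(score)
    return math.log(total)


# Hypothetical three-node example (parameters chosen for illustration only).
theta_example = [0.5, -0.3, 0.2]
edges_example = {(0, 1): 1.0, (1, 2): -0.5, (0, 2): 0.8}
q_example = mean_field_ising(theta_example, edges_example)
lower_bound = -mean_field_free_energy(q_example, theta_example, edges_example)
log_z = log_partition_brute_force(theta_example, edges_example)
```

Because $G_{MF}$ is not convex in general, the iteration may converge to a local optimum that depends on the initialization; the bound on $\log Z$ holds for any $q$ regardless.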
Structured Mean Field
Mean field theory applies to any tractable subgraph.
Naïve mean field is based on the fully disconnected subgraph.
Variants based on structured subgraphs, such as trees and chains, can be derived:
$$q(x) \propto \prod_{C \in \mathcal{C}} q_C(x_C)$$

Factorial HMM (Ghahramani & Jordan 97)
Can be used to model multiple independent latent processes.
Exact inference is in general intractable (why?)
  Complexity: $O(T M K^{M+1})$ (for $T$ time steps, $M$ chains, and $K$ states per chain), which is exponential in the number of chains $M$.

Structured Mean Field for Factorial HMM
Structured mean field approximation (with variational parameter $\lambda$):
$$Q(\{S_t\} \mid \lambda) \propto \prod_{m=1}^{M} Q(S_1^{(m)} \mid \lambda) \prod_{t=2}^{T} Q(S_t^{(m)} \mid S_{t-1}^{(m)}, \lambda)$$
The variational entropy term decouples into a sum, with one term for each chain.
In contrast to the completely factorized $Q$, optimizing w.r.t. $\lambda$ requires running the forward-backward algorithm as a subroutine.

Summary so far
Mean field methods minimize the KL divergence between the variational distribution and the target distribution by restricting the class of variational distributions.
They yield a lower bound on the log partition function, and are therefore a popular way to implement an approximate E-step in the EM algorithm.

Variational Principle

Inference Problems in Graphical Models
Undirected graphical model (MRF):
$$p(x) = \frac{1}{Z} \prod_{C \in \mathcal{C}} \psi_C(x_C)$$
The quantities of interest:
  marginal distributions: $p(x_i) = \sum_{x_j,\, j \neq i} p(x)$
  normalization constant (partition function): $Z$
Question: how to represent these quantities in a variational form?
Use tools from (1) exponential families; (2) convex analysis.

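Exponential Families
Canonical parameterization (w.r.t. measure $\nu$):
$$p(x; \theta) = \exp\{ \langle \theta, \phi(x) \rangle - A(\theta) \}$$
where $\theta$ are the canonical parameters, $\phi(x)$ the sufficient statistics, and $A(\theta)$ the log partition function.
Log normalization constant:
$$A(\theta) = \log \int \exp\{ \theta^T \phi(x) \} \, dx$$
  it is a convex function (Prop 3.1 in Wainwright & Jordan)
Effective canonical parameters: $\Omega := \{ \theta \in R^d \mid A(\theta) < +\infty \}$
Regular family: $\Omega$ is an open set
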
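Examples:

| Family | $\mathcal{X}$ | $\nu$ | $\log p(x; \theta)$ | $A(\theta)$ |
|---|---|---|---|---|
| Bernoulli | $\{0, 1\}$ | Counting | $\theta x - A(\theta)$ | $\log[1 + \exp(\theta)]$ |
| Gaussian | $R$ | Lebesgue | $\theta_1 x + \theta_2 x^2 - A(\theta)$ | $-\frac{\theta_1^2}{4\theta_2} + \frac{1}{2} \log\frac{\pi}{-\theta_2}$ |
| Exponential | $(0, +\infty)$ | Lebesgue | $\theta(-x) - A(\theta)$ | $-\log \theta$ |
| Poisson | $\{0, 1, 2, \ldots\}$ | Counting, $h(x) = 1/x!$ | $\theta x - A(\theta)$ | $\exp(\theta)$ |
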
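As a sanity check on the log partition functions of these families, the sketch below compares closed forms with direct numerical evaluation of $\log \int \exp\{\theta^T \phi(x)\}\, \nu(dx)$. It is illustrative only: the function names, the integration range, the grid size, and the truncation of the Poisson sum are assumptions. For the Gaussian case the closed form used is $A(\theta) = -\theta_1^2 / (4\theta_2) + \tfrac{1}{2} \log(\pi / (-\theta_2))$ with $\theta_2 < 0$, which follows from completing the square.

```python
import math


def gaussian_log_partition_closed_form(theta1, theta2):
    """A(theta) for log p = theta1*x + theta2*x^2 - A(theta); requires theta2 < 0."""
    return -theta1 ** 2 / (4.0 * theta2) + 0.5 * math.log(math.pi / (-theta2))


def gaussian_log_partition_numeric(theta1, theta2, lo=-30.0, hi=30.0, n=200000):
    """Midpoint-rule approximation of log of the integral of exp(theta1*x + theta2*x^2)."""
    h = (hi - lo) / n
    total = 0.0
    for k in range(n):
        x = lo + (k + 0.5) * h
        total += math.exp(theta1 * x + theta2 * x * x)
    return math.log(total * h)


def poisson_log_partition_numeric(theta, n_terms=100):
    """Truncated sum over x of exp(theta*x) / x!, i.e. base measure h(x) = 1/x!."""
    total = sum(math.exp(theta * x) / math.factorial(x) for x in range(n_terms))
    return math.log(total)


# Hypothetical parameter values chosen for illustration.
gauss_closed = gaussian_log_partition_closed_form(1.0, -0.5)
gauss_numeric = gaussian_log_partition_numeric(1.0, -0.5)
poisson_numeric = poisson_log_partition_numeric(0.3)
poisson_closed = math.exp(0.3)
```

With $\theta = (1, -0.5)$ the Gaussian case corresponds to a unit-variance Gaussian with mean 1, so both values should be close to $\tfrac{1}{2} + \tfrac{1}{2} \log 2\pi$.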
Graphical Models as Exponential Families
Undirected graphical model (MRF):
$$p(x; \theta) = \frac{1}{Z(\theta)} \prod_{C \in \mathcal{C}} \psi(x_C; \theta_C)$$
MRF in an exponential form:
$$p(x; \theta) = \exp\Big\{ \sum_{C \in \mathcal{C}} \log \psi(x_C; \theta_C) - \log Z(\theta) \Big\}$$
  $\log \psi(x_C; \theta_C)$ can be written in a linear form after some reparameterization.
Sufficient statistics must respect the structure of the graph.

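Example: Hidden Markov Model
[Figure: HMM with hidden states x_1, ..., x_5 and observations y_1, ..., y_5; the edge potential θ_23(x_2, x_3) and the node potential θ_5(x_5) are labeled]
What are the sufficient statistics? What are the associated canonical parameters?
Sufficient statistics (indicator functions):
$$I_{s;j}(x_s) = \begin{cases} 1 & \text{if } x_s = j \\ 0 & \text{otherwise} \end{cases} \qquad I_{st;jk}(x_s, x_t) = \begin{cases} 1 & \text{if } x_s = j \text{ and } x_t = k \\ 0 & \text{otherwise} \end{cases}$$
Canonical parameters:
$$\theta_{st;jk} = \log P(x_t = k \mid x_s = j), \qquad \theta_{s;j} = \log P(y_s \mid x_s = j)$$
A compact form:
$$\theta_{st}(x_s, x_t) = \sum_{j,k} \theta_{st;jk} I_{st;jk}(x_s, x_t) = \log P(x_t \mid x_s)$$
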
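Example: Discrete MRF
[Figure: edge (s, t) with node potentials θ_s(x_s), θ_t(x_t) and edge potential θ_st(x_s, x_t)]
Indicators:
$$I_j(x_s) = \begin{cases} 1 & \text{if } x_s = j \\ 0 & \text{otherwise} \end{cases}$$
Parameters: $\theta_s = \{\theta_{s;j},\ j \in \mathcal{X}_s\}$, $\theta_{st} = \{\theta_{st;jk},\ (j,k) \in \mathcal{X}_s \times \mathcal{X}_t\}$
Compact form: $\theta_s(x_s) := \sum_j \theta_{s;j} I_j(x_s)$, $\theta_{st}(x_s, x_t) := \sum_{j,k} \theta_{st;jk} I_j(x_s) I_k(x_t)$
In exponential form:
$$p(x; \theta) \propto \exp\Big\{ \sum_{s \in V} \theta_s(x_s) + \sum_{(s,t) \in E} \theta_{st}(x_s, x_t) \Big\}$$
Why is this representation useful? How is it related to the inference problem?
  Computing the expectation of the sufficient statistics (the mean parameters) given the canonical parameters yields the marginals, e.g., $\mu_{s;j} = E[I_j(x_s)] = p(x_s = j)$.
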
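Example: Gaussian MRF
Consider a zero-mean multivariate Gaussian distribution that respects the Markov property of a graph.
  The Hammersley-Clifford theorem states that the precision matrix $\Lambda = \Sigma^{-1}$ also respects the graph structure, i.e., $\Lambda_{st} = 0$ whenever $(s,t) \notin E$.
Gaussian MRF in exponential form:
$$p(x) = \exp\Big\{ \tfrac{1}{2} \langle \Theta, x x^T \rangle - A(\Theta) \Big\}, \quad \text{where } \Theta = -\Lambda$$
Sufficient statistics are $\{x_s^2,\ s \in V;\ x_s x_t,\ (s,t) \in E\}$.
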
Computing Mean Parameter: Bernoulli
A single Bernoulli random variable:
$$p(x; \theta) = \exp\{\theta x - A(\theta)\}, \quad x \in \{0, 1\}, \quad A(\theta) = \log(1 + e^{\theta})$$
Computing its mean parameter from the canonical parameter:
$$\mu = p(x = 1) = E[x] = \frac{e^{\theta}}{1 + e^{\theta}}$$
We want to do this in a variational manner: cast the procedure of computing the mean in an optimization-based formulation.

Conjugate Dual Function
Given any function $f(\theta)$, its conjugate dual function is:
$$f^*(\mu) = \sup_{\theta} \{ \langle \theta, \mu \rangle - f(\theta) \}$$
The conjugate dual is always a convex function: it is the pointwise supremum of a class of linear functions of $\mu$.
Under some technical conditions on $f$ (convex and lower semicontinuous), the dual of the dual is the function itself: $f = (f^*)^*$, i.e.,
$$f(\theta) = \sup_{\mu} \{ \langle \theta, \mu \rangle - f^*(\mu) \}$$
See the Convex Optimization book by Boyd for more details.

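Computing Mean Parameter: Bernoulli
Compute the conjugate:
$$A^*(\mu) := \sup_{\theta \in R} \{ \mu \theta - \log[1 + \exp(\theta)] \}$$
Stationary condition: $\mu = \frac{e^{\theta}}{1 + e^{\theta}}$
We find
$$A^*(\mu) = \begin{cases} \mu \log \mu + (1 - \mu) \log(1 - \mu) & \text{if } \mu \in [0, 1] \\ +\infty & \text{otherwise} \end{cases}$$
The variational form to compute the mean:
$$A(\theta) = \max_{\mu \in [0,1]} \{ \mu \theta - A^*(\mu) \}$$
The optimum is achieved at $\mu = \frac{e^{\theta}}{1 + e^{\theta}}$.
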
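The variational form for the Bernoulli case can be checked numerically. The sketch below is an illustration, not part of the lecture: the function names, the grid search, and the value of $\theta$ are assumptions. It maximizes $\mu\theta - A^*(\mu)$ over a fine grid on $[0, 1]$ and compares the optimal value with $A(\theta) = \log(1 + e^{\theta})$ and the maximizer with $e^{\theta} / (1 + e^{\theta})$.

```python
import math


def log_partition(theta):
    """A(theta) = log(1 + exp(theta)) for a single Bernoulli variable."""
    return math.log1p(math.exp(theta))


def conjugate_dual(mu):
    """A*(mu): negative Bernoulli entropy on [0, 1] (with 0 log 0 = 0), +inf outside."""
    if mu < 0.0 or mu > 1.0:
        return math.inf
    total = 0.0
    for p in (mu, 1.0 - mu):
        if p > 0.0:
            total += p * math.log(p)
    return total


def variational_log_partition(theta, grid_size=100001):
    """Grid search for max over mu in [0, 1] of mu*theta - A*(mu).

    Returns (optimal value, maximizing mu).
    """
    best_value, best_mu = -math.inf, None
    for k in range(grid_size):
        mu = k / (grid_size - 1)
        value = mu * theta - conjugate_dual(mu)
        if value > best_value:
            best_value, best_mu = value, mu
    return best_value, best_mu


# Hypothetical canonical parameter chosen for illustration.
theta_example = 0.7
value_example, mu_example = variational_log_partition(theta_example)
exact_mean = math.exp(theta_example) / (1.0 + math.exp(theta_example))
```

A grid search is used only to keep the example transparent; since the objective is concave in $\mu$, any one-dimensional concave maximization routine would serve equally well.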
Next Step
The last identity is not a coincidence but a deep theorem that holds for general exponential families.
However, for general graphical models/exponential families, computing the conjugate dual (the negative entropy) is intractable.
Moreover, the constraint set of mean parameters is hard to characterize.
Relaxing/approximating them leads to different algorithms: loopy belief propagation, naïve mean field, etc.