Statistical Machine Learning
Lecture 4: Variational Bayes
Melih Kandemir
Özyeğin University, İstanbul, Turkey
Synonyms
- Variational Bayes
- Variational Inference
- Variational Bayesian Inference
- Mean-Field Inference (in certain cases)
Information
We would like to measure the amount of information received when a binary variable $x \in \{0, 1\}$ is observed.
Information: the degree of surprise after observing $x$.
Devise a function $h(x)$ that measures the information gained from $x$.
What should h(x) look like?
When we observe two independent binary variables $x$ and $y$, the information received should be the sum of the information from the individual events: $h(x, y) = h(x) + h(y)$.
Because independence implies $p(x, y) = p(x)p(y)$, and the logarithm turns products into sums, it is suitable to measure information by
$$h(x) = -\log_2 p(x).$$
The choice of base 2 is arbitrary. When base 2 is used, the unit of information is called a bit!
The negative sign ensures that the measure of information is non-negative, since $p(x) \le 1$.
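As a quick numerical illustration of the additivity argument, the sketch below evaluates $h(x) = -\log_2 p(x)$ for two independent outcomes and for their joint occurrence. The probabilities 0.5 and 0.125 and the helper name `information_bits` are hypothetical choices made only for illustration.

```python
import math

def information_bits(p):
    # Hypothetical helper: h(x) = -log2 p(x), the "surprise" of an outcome with probability p, in bits
    return -math.log2(p)

# Hypothetical probabilities of two independent outcomes x and y
p_x, p_y = 0.5, 0.125
p_xy = p_x * p_y  # independence: p(x, y) = p(x) p(y)

print(information_bits(p_x), information_bits(p_y))  # 1.0 3.0
print(information_bits(p_xy))                        # 4.0, i.e. 1.0 + 3.0
```

A rarer outcome (smaller $p$) yields more bits, and the joint surprise is exactly the sum of the individual ones.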
Entropy
The expected amount of information for $x \sim p(x)$:
$$H[x] = -\sum_x p(x) \log_2 p(x).$$
Note that the case $p(x) = 0$ looks degenerate. Handle this by $\lim_{p \to 0} p \ln p = 0$: states with $p(x) = 0$ contribute nothing to $H[x]$.
Example 1
Consider the case where we have four possible states. When they are equally likely, the entropy turns out to be
$$H[x] = -4 \cdot \frac{1}{4} \log_2 \frac{1}{4} = 2 \text{ bits}.$$
Example 2
Assume we again have four possible states, this time with probabilities $\left(\frac{5}{8}, \frac{1}{4}, \frac{1}{16}, \frac{1}{16}\right)$. Then the entropy is
$$H[x] = -\frac{5}{8}\log_2\frac{5}{8} - \frac{1}{4}\log_2\frac{1}{4} - 2 \cdot \frac{1}{16}\log_2\frac{1}{16} = 0.42 + 0.5 + 0.5 = 1.42 \text{ bits}.$$
There is more information (higher entropy) in the uniform case!
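Both examples can be reproduced with a few lines of Python. This is a minimal sketch; the function name `entropy_bits` is our own, and it applies the convention above that zero-probability states contribute nothing.

```python
import math

def entropy_bits(probs):
    # Hypothetical helper: H[x] = -sum_x p(x) log2 p(x); states with p(x) = 0 are skipped (0 log 0 := 0)
    return -sum(p * math.log2(p) for p in probs if p > 0)

uniform = [1/4] * 4
skewed = [5/8, 1/4, 1/16, 1/16]

print(entropy_bits(uniform))           # Example 1 -> 2.0
print(round(entropy_bits(skewed), 2))  # Example 2 -> 1.42
```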
Measures of information content
- $-\log_2 p(x)$: bits
- $-\ln p(x)$: nats
Distributions that maximize the entropy
- Discrete (over a finite set of states): uniform
- Continuous (for a given mean and variance): normal!
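Since $\ln p = \ln 2 \cdot \log_2 p$, one bit equals $\ln 2 \approx 0.693$ nats. The sketch below (helper names are hypothetical) converts the entropy of Example 2 between the two units and draws random distributions over four states to check that none exceeds the uniform distribution's $\log_2 4 = 2$ bits. A random search is only a sanity check of the discrete maximum-entropy claim, not a proof.

```python
import math
import random

def entropy(probs, base=2.0):
    # Hypothetical helper: -sum p log_base p, skipping zero-probability states (0 log 0 := 0)
    return -sum(p * math.log(p, base) for p in probs if p > 0)

def max_entropy_of_random_dists(n_states, trials, seed=0):
    # Hypothetical helper: largest entropy (bits) among randomly drawn distributions over n_states
    rng = random.Random(seed)
    best = 0.0
    for _ in range(trials):
        w = [rng.random() for _ in range(n_states)]
        total = sum(w)
        best = max(best, entropy([wi / total for wi in w]))
    return best

p = [5/8, 1/4, 1/16, 1/16]
print(entropy(p, 2.0), entropy(p, math.e))  # bits, nats
print(max_entropy_of_random_dists(4, 10000), math.log2(4))
```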
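Entropy in the continuous domain
There is no exact counterpart. Discretize $x$ into bins of width $\Delta$ and, using the mean value theorem, pick $x_i$ in each bin such that $\int_{\text{bin } i} p(x)\,dx = p(x_i)\Delta$. Since $\sum_i p(x_i)\Delta = 1$, the entropy of the discretized variable is
$$H_\Delta = -\sum_i p(x_i)\Delta \ln\big(p(x_i)\Delta\big) = -\sum_i p(x_i)\Delta \ln p(x_i) - \ln \Delta,$$
so, as $\Delta \to 0$, it differs by $-\ln \Delta$ from the term below:
$$H[x] = -\int p(x) \ln p(x)\,dx.$$
This term is called the differential entropy. Although the gap $-\ln\Delta$ between the differential entropy and the exact entropy diverges as $\Delta \to 0$, the differential entropy is often used in place of the plain entropy for continuous densities. We will adopt the same convention here.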
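A small numerical check of the $-\ln\Delta$ gap, assuming a zero-mean normal density $\mathcal{N}(0, \sigma^2)$, whose differential entropy has the closed form $\frac{1}{2}\ln(2\pi e \sigma^2)$. Bin probabilities are computed from the normal CDF via `math.erf`, and truncating the range at $\pm 10\sigma$ is an arbitrary choice of this sketch; the helper names are hypothetical.

```python
import math

def normal_cdf(x, sigma):
    # CDF of N(0, sigma^2), written via the error function
    return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))

def discretized_entropy(sigma, delta, half_width=10.0):
    # Entropy (nats) of x ~ N(0, sigma^2) quantized into bins of width delta,
    # truncated to [-half_width * sigma, half_width * sigma]
    n_bins = int(round(2.0 * half_width * sigma / delta))
    lo = -half_width * sigma
    h = 0.0
    for i in range(n_bins):
        p = normal_cdf(lo + (i + 1) * delta, sigma) - normal_cdf(lo + i * delta, sigma)
        if p > 0.0:  # empty bins contribute nothing (0 ln 0 := 0)
            h -= p * math.log(p)
    return h

def differential_entropy_normal(sigma):
    # Closed form of -integral p(x) ln p(x) dx for N(0, sigma^2)
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)

for delta in (0.1, 0.01):
    h_delta = discretized_entropy(1.0, delta)
    approx = differential_entropy_normal(1.0) - math.log(delta)
    print(f"delta={delta}: H_delta={h_delta:.5f}, differential - ln(delta)={approx:.5f}")
```

As $\Delta$ shrinks, $H_\Delta$ grows without bound while staying close to the differential entropy minus $\ln\Delta$.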
Relative entropy or KL divergence
Suppose that, for some reason, we need to approximate $p(x)$ by another density $q(x)$ that has some more pleasant properties. The additional information (in nats, i.e., using the natural logarithm) required as a result of using $q(x)$ in place of $p(x)$ is
$$\big(-\log q(x)\big) - \big(-\log p(x)\big) = -\log\frac{q(x)}{p(x)} = \log\frac{p(x)}{q(x)}.$$
Since $x$ follows $p(x)$, the expected additional information is
$$\int p(x) \log\frac{p(x)}{q(x)}\,dx.$$
This quantity is called the relative entropy or Kullback-Leibler (KL) divergence and is denoted by $\mathrm{KL}[p\,\|\,q]$.
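For discrete distributions the integral becomes a sum. The sketch below (helper names are our own; it assumes $q(x) > 0$ wherever $p(x) > 0$, otherwise the divergence is infinite) computes $\mathrm{KL}[p\,\|\,q]$ in nats directly, and again as "expected information when coding with $q$" minus "expected information when coding with $p$". It uses the distribution from Example 2 as $p$ and the uniform distribution as $q$.

```python
import math

def entropy(p):
    # Expected information (nats) of x ~ p
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    # Expected information (nats) when x ~ p but q is used in its place; assumes q > 0 where p > 0
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl(p, q):
    # KL[p || q] = sum_x p(x) log(p(x) / q(x)): the expected *additional* information
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [5/8, 1/4, 1/16, 1/16]
q = [1/4] * 4
print(kl(p, q), cross_entropy(p, q) - entropy(p))  # the two values agree
```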
Convexity
Consider the line segment $a\lambda + b(1-\lambda)$, $\lambda \in [0, 1]$, between points $a$ and $b$, and an arbitrary function $f(x)$. If, for every choice of $a$ and $b$, the chord joining $(a, f(a))$ and $(b, f(b))$ lies on or above the graph of $f(x)$ between $a$ and $b$, then $f(x)$ is called a convex function. More formally, $f(x)$ is convex if for any $a$, $b$ and any $\lambda \in [0, 1]$ the inequality below is satisfied:
$$f(a)\lambda + f(b)(1-\lambda) \ge f\big(a\lambda + b(1-\lambda)\big).$$
Figure: C. Bishop, Pattern Recognition and Machine Learning, Springer, 2006.
Jensen's inequality
We can prove by induction that convexity also holds for more than two points:
$$\sum_{i=1}^{M} \lambda_i f(x_i) \ge f\left(\sum_{i=1}^{M} \lambda_i x_i\right),$$
where $\{x_1, \ldots, x_M\}$ is a set of points in the function's domain and $\sum_{i=1}^{M} \lambda_i = 1$ with $\lambda_i \ge 0$. We can think of $\{\lambda_1, \ldots, \lambda_M\}$ as the parameters of a categorical distribution with $M$ states. Hence we can equivalently write
$$\mathbb{E}[f(x)] \ge f(\mathbb{E}[x]).$$
This outcome generalizes straightforwardly to continuous variables (use Riemann integration):
$$\int f(x)\, p(x)\,dx \ge f\left(\int x\, p(x)\,dx\right).$$
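The discrete form of the inequality is easy to probe numerically. This sketch uses $f(x) = -\log x$ (convex for $x > 0$) and random categorical weights; the point ranges, the number of trials, and the helper names are arbitrary, hypothetical choices.

```python
import math
import random

def jensen_gap(f, xs, lambdas):
    # E[f(x)] - f(E[x]) for a categorical distribution putting weight lambdas[i] on xs[i]
    e_x = sum(lam * x for lam, x in zip(lambdas, xs))
    e_fx = sum(lam * f(x) for lam, x in zip(lambdas, xs))
    return e_fx - f(e_x)

def neg_log(x):
    # -log x is convex on x > 0
    return -math.log(x)

rng = random.Random(0)
gaps = []
for _ in range(1000):
    M = rng.randint(2, 6)
    xs = [rng.uniform(0.1, 10.0) for _ in range(M)]
    w = [rng.random() for _ in range(M)]
    total = sum(w)
    gaps.append(jensen_gap(neg_log, xs, [wi / total for wi in w]))

print(min(gaps) >= -1e-12)  # the gap is never negative (up to rounding)
```

The gap vanishes when all the $x_i$ coincide, which is the equality case used on the next slide.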
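KL divergence is a dissimilarity measure
Since $-\log x$ is a convex function, Jensen's inequality gives
$$\mathrm{KL}[p\,\|\,q] = -\int p(x) \log\frac{q(x)}{p(x)}\,dx \ge -\log \int \frac{q(x)}{p(x)}\, p(x)\,dx = -\log \underbrace{\int q(x)\,dx}_{1} = 0.$$
Because $-\log x$ is strictly convex, equality holds only when $q(x)/p(x)$ is constant (and hence equal to 1, since both densities integrate to 1), so $p(x) = q(x) \Leftrightarrow \mathrm{KL}[p\,\|\,q] = 0$. Therefore, KL divergence is a dissimilarity measure between two densities. It is not a metric, however: in general $\mathrm{KL}[p\,\|\,q] \neq \mathrm{KL}[q\,\|\,p]$.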
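A quick numerical check of these properties, with strictly positive discrete distributions standing in for densities (an assumption made only to keep the code short; helper names are hypothetical): KL is non-negative across random pairs, zero for identical arguments, and asymmetric.

```python
import math
import random

def kl(p, q):
    # KL[p || q] in nats, for strictly positive discrete distributions
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def random_dist(k, rng):
    # Hypothetical helper: a random strictly positive distribution over k states
    w = [rng.random() + 1e-3 for _ in range(k)]
    s = sum(w)
    return [wi / s for wi in w]

rng = random.Random(0)
pairs = [(random_dist(5, rng), random_dist(5, rng)) for _ in range(1000)]
min_kl = min(kl(p, q) for p, q in pairs)

p, q = [0.9, 0.1], [0.5, 0.5]
print(min_kl >= 0.0, kl(p, p))  # non-negative; zero when the arguments coincide
print(kl(p, q), kl(q, p))       # about 0.368 vs 0.511: not symmetric
```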
Calculus of variations
Typically we have scalars or vectors as variables, and we operate on mappings from these variables to other entities. For instance, in $f(x): \mathbb{R}^D \to \mathbb{R}$, the vector $x$ is our variable of interest and $f(\cdot)$ is a function of it.
In some cases we take functions as the variables of interest and operate on mappings from functions to other entities, e.g., $F: f(x) \mapsto F[f] \in \mathbb{R}$. Such mappings are called functionals. One example is the KL divergence. The branch of mathematics that focuses on functionals is called the calculus of variations.
What if we have non-conjugate priors?
Assume we are given a data set $X = \{x_1, \ldots, x_N\}$ and a Bayesian model
$$X \mid \theta \sim \prod_{n=1}^{N} p(x_n \mid \theta), \qquad \theta \sim p(\theta),$$
where the prior $p(\theta)$ on the set of latent variables is non-conjugate with respect to the likelihood $p(x_n \mid \theta)$. We are interested in the posterior $p(\theta \mid X)$, which does not have a closed-form solution. What shall we do then?
Approximating the posterior
Choose $q(\theta \mid \gamma)$, a density parameterized by $\gamma$, and construct an optimization problem that makes $q(\theta \mid \gamma)$ as similar as possible to the true posterior $p(\theta \mid X)$. But what sort of optimization problem would be suitable? Hint: put the pieces together.
How about this?
$$\arg\min_{q(\theta|\gamma)} \mathrm{KL}[p(\theta|X)\,\|\,q(\theta|\gamma)]$$
Did we solve the problem now?
How about this? Not quite!
$$\mathrm{KL}[p(\theta|X)\,\|\,q(\theta|\gamma)] = \int p(\theta|X) \log\frac{p(\theta|X)}{q(\theta|\gamma)}\,d\theta.$$
The loss function depends on $p(\theta|X)$, which we do not know. We are back where we started!
How about the other way around?
$$\arg\min_{q(\theta|\gamma)} \mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta|X)]$$
This one is at least worth pursuing. Approximating the posterior by solving this optimization problem is called Variational Bayes!
There are also ways to make progress from $\mathrm{KL}[p(\theta|X)\,\|\,q(\theta|\gamma)]$ by introducing further approximations; this route leads to Expectation Propagation. We will cover that approach towards the end of the semester.
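Variational Bayes
Using $p(\theta|X) = p(\theta, X)/p(X)$:
$$\begin{aligned}
\mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta|X)] &= \int q(\theta|\gamma) \log\frac{q(\theta|\gamma)}{p(\theta|X)}\,d\theta \\
&= \int q(\theta|\gamma) \log\frac{q(\theta|\gamma)\,p(X)}{p(\theta, X)}\,d\theta \\
&= \int q(\theta|\gamma) \log q(\theta|\gamma)\,d\theta + \int q(\theta|\gamma) \log p(X)\,d\theta - \int q(\theta|\gamma) \log p(\theta, X)\,d\theta.
\end{aligned}$$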
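Variational Bayes
$$\mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta|X)] = \underbrace{\mathbb{E}_{q(\theta|\gamma)}[\log q(\theta|\gamma)]}_{-H_{q(\theta|\gamma)}[\theta]} + \underbrace{\mathbb{E}_{q(\theta|\gamma)}[\log p(X)]}_{\log p(X)} - \mathbb{E}_{q(\theta|\gamma)}[\log p(\theta, X)].$$
Rearranging the terms, we get the interesting outcome below:
$$\underbrace{\log p(X)}_{\text{const}} = \underbrace{\mathbb{E}_{q(\theta|\gamma)}[\log p(\theta, X)] + H_{q(\theta|\gamma)}[\theta]}_{\mathcal{L}} + \underbrace{\mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta|X)]}_{\ge 0}.$$
Hence, $\mathcal{L}$ is a lower bound on the log of the evidence, which is why it is called the Evidence Lower Bound (ELBO). The ELBO equals the log evidence iff $q(\theta|\gamma) = p(\theta|X)$.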
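The decomposition can be verified exactly on a model small enough to enumerate. In this sketch, which is a hypothetical toy example and not part of the lecture, $\theta$ is a coin bias restricted to three values with a uniform prior, $X$ is a short sequence of coin flips, and $q$ is an arbitrary distribution over the three values.

```python
import math

# Hypothetical toy model: theta is a coin bias restricted to three values
thetas = [0.2, 0.5, 0.8]
prior = [1/3, 1/3, 1/3]
X = [1, 1, 0, 1, 1]  # observed coin flips (1 = heads)

def log_joint(k):
    # log p(theta_k, X) = log p(theta_k) + sum_n log p(x_n | theta_k)
    t = thetas[k]
    return math.log(prior[k]) + sum(math.log(t if x == 1 else 1 - t) for x in X)

log_evidence = math.log(sum(math.exp(log_joint(k)) for k in range(3)))
posterior = [math.exp(log_joint(k) - log_evidence) for k in range(3)]

def elbo(q):
    # E_q[log p(theta, X)] + H_q[theta]
    return sum(qk * (log_joint(k) - math.log(qk)) for k, qk in enumerate(q) if qk > 0)

def kl(q, p):
    # KL[q || p] for discrete distributions
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p) if qk > 0)

q = [0.5, 0.3, 0.2]  # an arbitrary approximation
print(log_evidence, elbo(q) + kl(q, posterior))  # equal: log p(X) = ELBO + KL
print(elbo(q) < log_evidence)                    # the ELBO is a strict lower bound here
print(elbo(posterior), log_evidence)             # the bound is tight at q = posterior
```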
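Variational Bayes: inference as optimization
Since $\log p(X)$ does not depend on $q(\theta|\gamma)$, minimizing the KL divergence is equivalent to maximizing the ELBO:
$$\arg\min_{q(\theta|\gamma)} \mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta|X)] = \arg\max_{q(\theta|\gamma)} \mathcal{L}.$$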
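Inference as optimization
Let us take a closer look at the generic form and consider the feasibility of the approach. Using the model's factorization $p(\theta, X) = p(\theta)\prod_{n=1}^{N} p(x_n|\theta)$,
$$\begin{aligned}
\arg\max_{q(\theta|\gamma)} \mathcal{L} &= \arg\max_{\gamma} \left\{ \sum_{n=1}^{N} \mathbb{E}_{q(\theta|\gamma)}[\log p(x_n|\theta)] + \mathbb{E}_{q(\theta|\gamma)}[\log p(\theta)] + H_{q(\theta|\gamma)}[\theta] \right\} \\
&= \arg\max_{\gamma} \left\{ \sum_{n=1}^{N} \mathbb{E}_{q(\theta|\gamma)}[\log p(x_n|\theta)] - \mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta)] \right\}.
\end{aligned}$$
Calculate $\mathbb{E}_{q(\theta|\gamma)}[\log p(x_n|\theta)]$ and look up $H_{q(\theta|\gamma)}[\theta]$ or, alternatively, $\mathrm{KL}[q(\theta|\gamma)\,\|\,p(\theta)]$. Take the gradient of the ELBO with respect to $\gamma$ and optimize.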
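To make the recipe concrete, here is a minimal sketch on a hypothetical conjugate model, chosen only because its exact posterior is known and the result can therefore be checked: $x_n|\theta \sim \mathcal{N}(\theta, 1)$, $\theta \sim \mathcal{N}(0, 1)$, and $q(\theta|\gamma) = \mathcal{N}(m, s^2)$ with $\gamma = (m, u)$, $s^2 = e^{2u}$. Here both $\mathbb{E}_q[\log p(x_n|\theta)]$ and $\mathrm{KL}[q\,\|\,p(\theta)]$ have closed forms, and the gradients are derived by hand. In a genuinely non-conjugate model these expectations would usually need further approximation, which this sketch does not cover. Plain gradient ascent with a fixed step size is an arbitrary choice.

```python
import math
import random

# Hypothetical model: x_n | theta ~ N(theta, 1), theta ~ N(0, 1), q(theta | gamma) = N(m, exp(2u))

def elbo(x, m, u):
    s2 = math.exp(2.0 * u)
    # sum_n E_q[log N(x_n | theta, 1)] = sum_n [-0.5 log(2 pi) - 0.5 ((x_n - m)^2 + s^2)]
    expected_loglik = sum(-0.5 * math.log(2 * math.pi) - 0.5 * ((xn - m) ** 2 + s2) for xn in x)
    # KL[N(m, s^2) || N(0, 1)] in closed form
    kl = 0.5 * (s2 + m ** 2 - 1.0 - math.log(s2))
    return expected_loglik - kl

def elbo_grad(x, m, u):
    s2 = math.exp(2.0 * u)
    dm = sum(xn - m for xn in x) - m  # d ELBO / d m
    du = 1.0 - (len(x) + 1) * s2      # d ELBO / d u
    return dm, du

def fit(x, lr=0.01, steps=3000):
    # Gradient ascent on the variational parameters gamma = (m, u)
    m, u = 0.0, 0.0
    for _ in range(steps):
        dm, du = elbo_grad(x, m, u)
        m += lr * dm
        u += lr * du
    return m, math.exp(2.0 * u)

random.seed(0)
x = [random.gauss(1.5, 1.0) for _ in range(20)]  # synthetic data
m, s2 = fit(x)
# Exact posterior for this model: N(sum(x) / (N + 1), 1 / (N + 1))
print(m, sum(x) / (len(x) + 1))
print(s2, 1.0 / (len(x) + 1))
```

The optimizer recovers the exact posterior only because, in this toy case, the true posterior lies inside the variational family.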
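Mean-Field Variational Bayes
Let us choose $q(\theta) = \prod_{i \in P} q(\theta_i)$, where $P$ is a partitioning of the set of all latent variables. Expressing the ELBO in terms of one of the partitions gives
$$\mathcal{L} = \int q(\theta_j) \underbrace{\left\{ \int \log p(X, \theta) \prod_{i \neq j} q(\theta_i)\,d\theta_i \right\}}_{\mathbb{E}_{q(\theta)\setminus q(\theta_j)}[\log p(X, \theta)]} d\theta_j - \int q(\theta_j) \log q(\theta_j)\,d\theta_j + \text{const},$$
where the constant collects the entropies of the other factors $q(\theta_i)$, $i \neq j$.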
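Mean-Field Variational Bayes
Let us define a new (unnormalized) density $\tilde{p}(\theta_j)$ through $\log \tilde{p}(\theta_j) = \mathbb{E}_{q(\theta)\setminus q(\theta_j)}[\log p(X, \theta)]$, and fix all the factors except $q(\theta_j)$. Then $\mathcal{L} = -\mathrm{KL}[q(\theta_j)\,\|\,\tilde{p}(\theta_j)/Z_j] + \text{const}$, so the ideal $q(\theta_j)$ would make this KL divergence $0$. Hence,
$$q(\theta_j) = \underbrace{\exp\left(\mathbb{E}_{q(\theta)\setminus q(\theta_j)}[\log p(X, \theta)]\right)}_{\tilde{p}(\theta_j)} \Big/ Z_j, \qquad Z_j = \int \tilde{p}(\theta_j)\,d\theta_j.$$
Update all partitions individually using this update rule, and iterate until convergence.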
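The update rule can be tried on the bivariate Gaussian example from Bishop's PRML (Chapter 10), treating a Gaussian $p(z) = \mathcal{N}(z\,|\,\mu, \Lambda^{-1})$ with precision matrix $\Lambda$ as the target in place of $p(\theta|X)$, with $q(z) = q(z_1)q(z_2)$. Applying the rule gives $q(z_1) = \mathcal{N}(z_1\,|\,m_1, \Lambda_{11}^{-1})$ with $m_1 = \mu_1 - \Lambda_{11}^{-1}\Lambda_{12}(m_2 - \mu_2)$, and symmetrically for $z_2$. The values of $\mu$ and $\Lambda$ and the helper name below are arbitrary, hypothetical choices.

```python
def cavi_bivariate_gaussian(mu, Lam, iters=100):
    # Hypothetical helper: mean-field updates for a Gaussian target N(mu, Lam^{-1}) with q(z) = q(z1) q(z2).
    # Each factor is Gaussian with fixed variance 1 / Lam_ii; only the means need iterating.
    m1, m2 = 0.0, 0.0
    for _ in range(iters):
        m1 = mu[0] - Lam[0][1] / Lam[0][0] * (m2 - mu[1])  # update q(z1) given the current q(z2)
        m2 = mu[1] - Lam[1][0] / Lam[1][1] * (m1 - mu[0])  # update q(z2) given the new q(z1)
    return (m1, 1.0 / Lam[0][0]), (m2, 1.0 / Lam[1][1])

mu = (1.0, -2.0)
Lam = ((2.0, 0.9), (0.9, 1.0))  # a positive-definite precision matrix

(m1, v1), (m2, v2) = cavi_bivariate_gaussian(mu, Lam)
det = Lam[0][0] * Lam[1][1] - Lam[0][1] * Lam[1][0]
print(m1, m2)                                  # converge to mu = (1, -2)
print(v1, v2)                                  # factor variances 0.5 and 1.0
print(Lam[1][1] / det, Lam[0][0] / det)        # true marginal variances, about 0.84 and 1.68
```

The means converge to $\mu$, but the factor variances are smaller than the true marginal variances, illustrating that this KL direction tends to underestimate spread when the variables are correlated.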
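Genuine Bayesian linear regression
$$y \mid X, w \sim \prod_{n=1}^{N} \mathcal{N}(y_n \mid w^\top x_n, \beta^{-1}), \qquad w \mid \alpha \sim \mathcal{N}(w \mid 0, \alpha^{-1} I), \qquad \alpha \sim \mathcal{G}(\alpha \mid a_0, b_0).$$
Here $\beta$ is treated as a known noise precision, and $\mathcal{G}(\alpha \mid a_0, b_0)$ is a Gamma density with shape $a_0$ and rate $b_0$. Approximate $p(w, \alpha \mid X, y)$ with $q(w, \alpha) = q(w)q(\alpha)$.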
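Update for q(w)
$$\begin{aligned}
\log q(w) &= \mathbb{E}_{q(\alpha)}[\log p(y \mid w, X)] + \mathbb{E}_{q(\alpha)}[\log p(w \mid \alpha)] + \underbrace{\mathbb{E}_{q(\alpha)}[\log p(\alpha)]}_{\text{const}} + \text{const} \\
&= -\frac{\beta}{2} w^\top X^\top X w + \beta y^\top X w - \frac{\mathbb{E}_{q(\alpha)}[\alpha]}{2} w^\top w + \text{const} \\
&= -\frac{1}{2} w^\top \Big[\beta X^\top X + \underbrace{\mathbb{E}_{q(\alpha)}[\alpha]}_{a} I\Big] w + \beta y^\top X w + \text{const},
\end{aligned}$$
where $X$ is the $N \times D$ matrix with rows $x_n^\top$. Completing the square, taking the exponent, and normalizing gives the update rule $q(w) = \mathcal{N}(w \mid m, S)$, where
$$S = \big[\beta X^\top X + a I\big]^{-1}, \qquad m = \beta S X^\top y.$$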
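Update for q(α)
$$\begin{aligned}
\log q(\alpha) &= \underbrace{\mathbb{E}_{q(w)}[\log p(y \mid w, X)]}_{\text{const}} + \mathbb{E}_{q(w)}[\log p(w \mid \alpha)] + \mathbb{E}_{q(w)}[\log p(\alpha)] + \text{const} \\
&= \underbrace{-\frac{1}{2}\log\left|\alpha^{-1} I\right|}_{\frac{D}{2}\log\alpha} - \frac{\alpha}{2}\,\mathbb{E}_{q(w)}[w^\top w] + (a_0 - 1)\log\alpha - b_0 \alpha + \text{const} \\
&= \left(\frac{D}{2} + a_0 - 1\right)\log\alpha - \left(\frac{1}{2}\mathbb{E}_{q(w)}[w^\top w] + b_0\right)\alpha + \text{const}.
\end{aligned}$$
Taking the exponent and normalizing gives
$$q(\alpha) = \mathcal{G}\left(\alpha \,\middle|\, \frac{D}{2} + a_0,\; \frac{1}{2}\mathbb{E}_{q(w)}[w^\top w] + b_0\right).$$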
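Handling the first and second moments
$$\begin{aligned}
\mathrm{cov}(x, y) &= \mathbb{E}\big[(x - \mathbb{E}[x])(y - \mathbb{E}[y])^\top\big] \\
&= \mathbb{E}\big[x y^\top - \mathbb{E}[x]\, y^\top - x\,\mathbb{E}[y]^\top + \mathbb{E}[x]\mathbb{E}[y]^\top\big] \\
&= \mathbb{E}[x y^\top] - \mathbb{E}[x]\mathbb{E}[y]^\top - \mathbb{E}[x]\mathbb{E}[y]^\top + \mathbb{E}[x]\mathbb{E}[y]^\top \\
&= \mathbb{E}[x y^\top] - \mathbb{E}[x]\mathbb{E}[y]^\top.
\end{aligned}$$
Hence, $\mathbb{E}[x y^\top] = \mathrm{cov}(x, y) + \mathbb{E}[x]\mathbb{E}[y]^\top$. Taking the trace of both sides, and using the linearity of trace and expectation together with $\mathrm{Tr}(x y^\top) = x^\top y$, yields
$$\begin{aligned}
\mathrm{Tr}\big(\mathbb{E}[x y^\top]\big) &= \mathrm{Tr}[\mathrm{cov}(x, y)] + \mathrm{Tr}\big(\mathbb{E}[x]\mathbb{E}[y]^\top\big) \\
\mathbb{E}\big[\mathrm{Tr}(x y^\top)\big] &= \mathrm{Tr}[\mathrm{cov}(x, y)] + \mathrm{Tr}\big(\mathbb{E}[x]\mathbb{E}[y]^\top\big) \\
\mathbb{E}[x^\top y] &= \mathrm{Tr}[\mathrm{cov}(x, y)] + \mathbb{E}[x]^\top \mathbb{E}[y].
\end{aligned}$$
Consequently, for $w \sim q(w) = \mathcal{N}(w \mid m, S)$, $\mathbb{E}_{q(w)}[w^\top w] = \mathrm{Tr}(S) + m^\top m$. Note also that in the first update $a = \mathbb{E}_{q(\alpha)}[\alpha] = \left(D/2 + a_0\right) \big/ \left(\mathbb{E}_{q(w)}[w^\top w]/2 + b_0\right)$, the mean of the Gamma density.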
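Putting the two updates and the moment identity together gives a short coordinate-ascent loop. This is a minimal NumPy sketch under the slide's assumptions (known $\beta$, Gamma with shape $a_0$ and rate $b_0$). The function name, the initial guess $\mathbb{E}_q[\alpha] = 1$, the fixed iteration count, the hyperparameter values $a_0 = b_0 = 10^{-3}$, and the synthetic data are all hypothetical choices, not part of the lecture.

```python
import numpy as np

def vb_linear_regression(X, y, beta, a0=1e-3, b0=1e-3, iters=100):
    """Hypothetical helper: mean-field VB for y_n ~ N(w^T x_n, 1/beta), w ~ N(0, I/alpha), alpha ~ Gamma(a0, b0).

    beta is a known noise precision; the Gamma density uses shape a0 and rate b0.
    Returns m, S of q(w) = N(m, S) and a_N, b_N of q(alpha) = Gamma(a_N, b_N).
    """
    _, D = X.shape
    XtX, Xty = X.T @ X, X.T @ y
    a_N = a0 + D / 2.0  # the shape parameter does not change across iterations
    e_alpha = 1.0       # arbitrary initial guess for a = E_q[alpha]
    for _ in range(iters):
        # q(w) update: S = (beta X^T X + a I)^{-1}, m = beta S X^T y
        S = np.linalg.inv(beta * XtX + e_alpha * np.eye(D))
        m = beta * S @ Xty
        # q(alpha) update, using E[w^T w] = Tr(S) + m^T m
        b_N = b0 + 0.5 * (np.trace(S) + m @ m)
        e_alpha = a_N / b_N  # mean of Gamma(a_N, b_N), fed back into the q(w) update
    return m, S, a_N, b_N

rng = np.random.default_rng(0)
N, D, beta = 50, 3, 25.0
w_true = np.array([1.0, -2.0, 0.5])  # hypothetical ground truth for the synthetic data
X = rng.normal(size=(N, D))
y = X @ w_true + rng.normal(scale=1.0 / np.sqrt(beta), size=N)

m, S, a_N, b_N = vb_linear_regression(X, y, beta)
print("E_q[w]     =", m)
print("E_q[alpha] =", a_N / b_N)
```

A fixed number of sweeps keeps the sketch short; monitoring the change in $\mathbb{E}_q[\alpha]$ or in the ELBO would be a more careful stopping rule.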