LDA with Amortized Inference

Nanbo Sun

Abstract

This report describes how to frame Latent Dirichlet Allocation (LDA) as a Variational Auto-Encoder (VAE) and use Amortized Variational Inference (AVI) to optimize it.

1. We introduce LDA and use Mean Field Variational Inference (MFVI) to optimize it.
2. We collapse the topic assignments in the LDA model, because we cannot backpropagate through a categorical variable.
3. We introduce the Gaussian VAE with AVI.
4. We frame the collapsed LDA as a VAE and apply AVI.

1 Prerequisite

Before reading this report, you should read the following papers carefully.

1. Blei, D. M., Ng, A. Y., & Jordan, M. I. (2003). Latent Dirichlet allocation. Journal of Machine Learning Research, 3(Jan), 993-1022.
2. Kingma, D. P., & Welling, M. (2013). Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114.
3. Figurnov, M., Mohamed, S., & Mnih, A. (2018). Implicit reparameterization gradients. arXiv preprint arXiv:1805.08498.

2 LDA with MFVI

2.1 Generative Process

For a document d with K topics and V unique words,

1. draw a topic mixture θ ~ Dir(α), where θ is a K-vector and Σ_k θ_k = 1;
2. for each of the N_d words, independently
   (a) draw a topic z_n ~ Mult(θ);
   (b) draw a word w_n ~ p(w_n | z_n, β), where β is a K × V matrix, each row of which defines a multinomial distribution over the vocabulary; β_ij is the probability that word j appears given topic i.
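The generative process above can be sketched in a few lines of NumPy; the function and variable names here are illustrative, not part of any library.

```python
import numpy as np

def generate_document(alpha, beta, n_words, rng):
    """Sample one document from the LDA generative process.

    alpha: (K,) Dirichlet prior over topic mixtures.
    beta:  (K, V) per-topic word distributions (each row sums to 1).
    """
    theta = rng.dirichlet(alpha)                        # topic mixture theta ~ Dir(alpha)
    words = np.empty(n_words, dtype=int)
    for n in range(n_words):
        z_n = rng.choice(len(alpha), p=theta)           # topic z_n ~ Mult(theta)
        words[n] = rng.choice(beta.shape[1], p=beta[z_n])  # word w_n ~ Mult(beta_{z_n})
    return theta, words

rng = np.random.default_rng(0)
alpha = np.array([0.5, 0.5, 0.5])
beta = rng.dirichlet(np.ones(10), size=3)               # K=3 topics, V=10 words
theta, words = generate_document(alpha, beta, n_words=20, rng=rng)
```

Sampling a corpus simply repeats this for each of the D documents, drawing a fresh θ per document while sharing α and β.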
Figure 1: LDA model (plate diagram: α → θ_d → z_{d,n} → w_{d,n}, with β feeding w_{d,n}; plates over the N_d words and D documents).

Figure 2: Variational distribution used to approximate the posterior in LDA (plate diagram: γ_d → θ_d and φ_{d,n} → z_{d,n}).

2.2 Constructing the Lower Bound

From Figure 2, the variational distribution used to approximate the true posterior factorizes as

  q(θ, z | γ, φ) = q(θ | γ) Π_{n=1}^N q(z_n | φ_n).

The lower bound L(γ, φ; α, β) on the single-document log likelihood log p(w | α, β) (this also explains why the document subscript is dropped for simplicity hereafter) is obtained using Jensen's inequality as follows:

  log p(w | α, β) = log ∫ Σ_z p(θ, z, w | α, β) dθ
                  = log ∫ Σ_z q(θ, z | γ, φ) [p(θ, z, w | α, β) / q(θ, z | γ, φ)] dθ
                  ≥ ∫ Σ_z q(θ, z | γ, φ) log [p(θ, z, w | α, β) / q(θ, z | γ, φ)] dθ
                  = E_q{log p(θ, z, w | α, β)} − E_q{log q(θ, z | γ, φ)}
                  ≜ L(γ, φ; α, β).                                         (1)

The difference between the log likelihood and its lower bound can be shown to be exactly the KL divergence between the variational distribution and the true posterior:

  log p(w | α, β) − L(γ, φ; α, β)
    = E_q{log p(w | α, β)} − E_q{log p(θ, z, w | α, β)} + E_q{log q(θ, z | γ, φ)}
    = E_q{log [p(w | α, β) q(θ, z | γ, φ) / p(θ, z, w | α, β)]}
    = E_q{log [q(θ, z | γ, φ) / p(θ, z | w, α, β)]}
    = D_KL(q(θ, z | γ, φ) ‖ p(θ, z | w, α, β)).

Therefore, maximizing the lower bound is equivalent to minimizing the KL divergence. That is, the variational distribution automatically approaches the true posterior as we maximize the lower bound.

2.3 Expanding the Lower Bound

To maximize the lower bound, we first need to spell out L(γ, φ; α, β) in terms of the model parameters α, β and the variational parameters γ, φ. Continuing from (1), we have

  L(γ, φ; α, β) = E_q{log p(θ, z, w | α, β)} − E_q{log q(θ, z | γ, φ)}
                = E_q{log [p(θ | α) p(z | θ) p(w | z, β) / (q(θ | γ) q(z | φ))]}
                = E_q{log p(θ | α)} + E_q{log p(z | θ)} + E_q{log p(w | z, β)}
                  − E_q{log q(θ | γ)} − E_q{log q(z | φ)}.                 (2)

We now further expand each of the five terms in (2). The first term is

  E_q{log p(θ | α)} = E_q{log [Γ(Σ_{i=1}^K α_i) / Π_{i=1}^K Γ(α_i)] Π_{k=1}^K θ_k^{α_k − 1}}
    = Σ_{k=1}^K (α_k − 1) E_q{log θ_k} + log Γ(Σ_{i=1}^K α_i) − Σ_{k=1}^K log Γ(α_k)
    = Σ_{k=1}^K (α_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) + log Γ(Σ_{i=1}^K α_i) − Σ_{k=1}^K log Γ(α_k),

where Ψ is the digamma function, the first derivative of the log Gamma function. The final line is due to the following property of the Dirichlet distribution as a member of the exponential family: if θ ~ Dir(α), then E_{p(θ|α)}{log θ_i} = Ψ(α_i) − Ψ(Σ_{i=1}^K α_i).

The second term is

  E_q{log p(z | θ)} = E_q{log Π_{n=1}^N p(z_n | θ)} = E_q{log Π_{n=1}^N Π_{k=1}^K θ_k^{1(z_n = k)}}
    = Σ_{n=1}^N Σ_{k=1}^K E_q{1(z_n = k) log θ_k}
    = Σ_{n=1}^N Σ_{k=1}^K E_q{1(z_n = k)} E_q{log θ_k}
    = Σ_{n=1}^N Σ_{k=1}^K φ_{n,k} (Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)),

where φ_{n,k} is the probability of the nth word being produced by topic k, and 1(·) is the indicator function.

We expand the third term as

  E_q{log p(w | z, β)} = E_q{log Π_{n=1}^N p(w_n | z_n, β)}
    = E_q{log Π_{n=1}^N Π_{k=1}^K Π_{v=1}^V β_{k,v}^{1(z_n = k) 1(w_n = v)}}
    = Σ_{n=1}^N Σ_{k=1}^K Σ_{v=1}^V E_q{1(z_n = k)} 1(w_n = v) log β_{k,v}
    = Σ_{n=1}^N Σ_{k=1}^K Σ_{v=1}^V φ_{n,k} 1(w_n = v) log β_{k,v}.

Very similar to the first term, the fourth term is expanded as

  E_q{log q(θ | γ)} = Σ_{k=1}^K (γ_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) + log Γ(Σ_{i=1}^K γ_i) − Σ_{k=1}^K log Γ(γ_k).

Finally, the fifth term is expanded as

  E_q{log q(z | φ)} = E_q{log Π_{n=1}^N q(z_n | φ_n)} = E_q{log Π_{n=1}^N Π_{k=1}^K φ_{n,k}^{1(z_n = k)}}
    = Σ_{n=1}^N Σ_{k=1}^K φ_{n,k} log φ_{n,k}.

Therefore, the fully expanded lower bound is

  L(γ, φ; α, β) = Σ_{k=1}^K (α_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) + log Γ(Σ_{i=1}^K α_i) − Σ_{k=1}^K log Γ(α_k)
    + Σ_{n=1}^N Σ_{k=1}^K φ_{n,k} (Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i))
    + Σ_{n=1}^N Σ_{k=1}^K Σ_{v=1}^V φ_{n,k} 1(w_n = v) log β_{k,v}                  (3)
    − Σ_{k=1}^K (γ_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) − log Γ(Σ_{i=1}^K γ_i) + Σ_{k=1}^K log Γ(γ_k)
    − Σ_{n=1}^N Σ_{k=1}^K φ_{n,k} log φ_{n,k}.

2.4 Maximizing the Lower Bound

In this section, we maximize the lower bound w.r.t. the variational parameters φ and γ. Recall that as the maximization runs, the KL divergence between the variational distribution and the true posterior drops (the E-step of the variational EM algorithm).

2.4.1 Variational Multinomial

We first maximize Equation (3) w.r.t. φ_{n,k}. Since Σ_k φ_{n,k} = 1, this is a constrained optimization problem that can be solved by the Lagrange multiplier method. Writing v for the index of word w_n, the Lagrangian w.r.t. φ_{n,k} is

  L_[φ_{n,k}] = φ_{n,k} (Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) + φ_{n,k} log β_{k,v} − φ_{n,k} log φ_{n,k} + λ_n (Σ_{i=1}^K φ_{n,i} − 1),

where λ_n is the Lagrange multiplier. Taking the derivative, we get

  ∂L/∂φ_{n,k} = Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i) + log β_{k,v} − log φ_{n,k} − 1 + λ_n.

Setting this derivative to zero yields

  φ_{n,k} = β_{k,v} exp(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i) + λ_n − 1)
          ∝ β_{k,v} exp(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)).

2.4.2 Variational Dirichlet

Now we maximize Equation (3) w.r.t. γ_k, the kth component of the Dirichlet parameter. Only the terms containing γ are retained:

  L_[γ] = Σ_{k=1}^K (α_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i))
    + Σ_{n=1}^N Σ_{k=1}^K φ_{n,k} (Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i))
    − Σ_{k=1}^K (γ_k − 1)(Ψ(γ_k) − Ψ(Σ_{i=1}^K γ_i)) − log Γ(Σ_{i=1}^K γ_i) + Σ_{k=1}^K log Γ(γ_k).

Taking the derivative w.r.t. γ_k and collecting terms (the Ψ(γ_k) − Ψ(Σ γ_i) terms produced by the product rule cancel), we have

  ∂L_[γ]/∂γ_k = Ψ′(γ_k)(α_k + Σ_{n=1}^N φ_{n,k} − γ_k) − Ψ′(Σ_{i=1}^K γ_i) Σ_{i=1}^K (α_i + Σ_{n=1}^N φ_{n,i} − γ_i).

Setting it to zero, we have

  γ_k = α_k + Σ_{n=1}^N φ_{n,k}.

2.5 Estimating Model Parameters

The previous section is the E-step of the variational EM algorithm, where we tighten the lower bound w.r.t. the variational parameters; this section is the M-step, where we maximize the lower bound w.r.t. the model parameters α and β. Now we add back the document subscript to consider the whole corpus. By the assumed exchangeability of the documents, the overall log likelihood of the corpus is just the sum of all the documents' log likelihoods, and the overall lower bound is just the sum of the individual lower bounds.

Again, only the terms involving β are kept in the overall lower bound. Adding the Lagrange multipliers (one per topic, since each row of β must sum to one), we obtain

  L_[β] = Σ_{d=1}^D Σ_{n=1}^{N_d} Σ_{k=1}^K Σ_{v=1}^V φ_{d,n,k} 1(w_{d,n} = v) log β_{k,v} + Σ_{k=1}^K λ_k (Σ_{v=1}^V β_{k,v} − 1).

Taking the derivative w.r.t. β_{k,v} and setting it to zero, we have

  ∂L_[β]/∂β_{k,v} = Σ_{d=1}^D Σ_{n=1}^{N_d} φ_{d,n,k} 1(w_{d,n} = v) (1/β_{k,v}) + λ_k = 0

  β_{k,v} = −(1/λ_k) Σ_{d=1}^D Σ_{n=1}^{N_d} φ_{d,n,k} 1(w_{d,n} = v)
          ∝ Σ_{d=1}^D Σ_{n=1}^{N_d} φ_{d,n,k} 1(w_{d,n} = v).

Similarly, for α, we have

  L_[α] = Σ_{d=1}^D [Σ_{k=1}^K (α_k − 1)(Ψ(γ_{d,k}) − Ψ(Σ_{i=1}^K γ_{d,i})) + log Γ(Σ_{i=1}^K α_i) − Σ_{k=1}^K log Γ(α_k)]

  ∂L_[α]/∂α_k = Σ_{d=1}^D (Ψ(γ_{d,k}) − Ψ(Σ_{i=1}^K γ_{d,i})) + D (Ψ(Σ_{i=1}^K α_i) − Ψ(α_k)).

Since the derivative also depends on the other α_{k′} (k′ ≠ k), we compute the Hessian

  ∂²L_[α]/(∂α_k ∂α_{k′}) = D Ψ′(Σ_{i=1}^K α_i) − δ(k, k′) D Ψ′(α_k),

and notice that its form (a diagonal matrix plus a constant) allows for a linear-time Newton-Raphson algorithm.

3 Collapsed LDA without Topic Assignments

3.1 Generative Process

For a document d with K topics and V unique words,

1. draw a topic mixture θ ~ Dir(α), where θ is a K-vector and Σ_k θ_k = 1;
2. for each of the N_d words, independently draw a word w_n ~ p(w_n | θ, β), where β is a K × V matrix, each row of which defines a multinomial distribution over the vocabulary; β_ij is the probability that word j appears given topic i.
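Before moving on, the E-step coordinate updates derived in Section 2.4 — φ_{n,k} ∝ β_{k,w_n} exp(Ψ(γ_k) − Ψ(Σ_i γ_i)) and γ_k = α_k + Σ_n φ_{n,k} — can be sketched for a single document as follows. The names are illustrative; `scipy` supplies the digamma function Ψ.

```python
import numpy as np
from scipy.special import digamma

def e_step(words, alpha, beta, n_iters=50):
    """Coordinate ascent on the variational parameters (phi, gamma) for one document.

    words: (N,) word indices; alpha: (K,); beta: (K, V) with rows summing to 1.
    """
    K, N = beta.shape[0], len(words)
    phi = np.full((N, K), 1.0 / K)        # q(z_n | phi_n) initialized uniform
    gamma = alpha + N / K                 # q(theta | gamma) initialization
    for _ in range(n_iters):
        # phi_{n,k} ∝ beta_{k, w_n} * exp(digamma(gamma_k) - digamma(sum_i gamma_i))
        log_phi = np.log(beta[:, words].T) + digamma(gamma) - digamma(gamma.sum())
        phi = np.exp(log_phi - log_phi.max(axis=1, keepdims=True))
        phi /= phi.sum(axis=1, keepdims=True)
        # gamma_k = alpha_k + sum_n phi_{n,k}
        gamma = alpha + phi.sum(axis=0)
    return phi, gamma

rng = np.random.default_rng(0)
beta = rng.dirichlet(np.ones(6), size=2)  # K=2 topics, V=6 words
phi, gamma = e_step(np.array([0, 3, 3, 5]), alpha=np.array([0.8, 0.8]), beta=beta)
```

Normalizing φ in log space (subtracting the row maximum before exponentiating) is a standard trick to avoid underflow when K is large.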
Figure 3: Collapsed LDA model (plate diagram: α → θ_d → w_{d,n}, with β feeding w_{d,n}; plates over the N_d words and D documents).

Figure 4: Variational distribution used to approximate the posterior in collapsed LDA (plate diagram: γ_d → θ_d; plate over the D documents).

3.2 Comparing LDA vs. Collapsed LDA

Table 1: Comparing LDA vs. collapsed LDA.

                         LDA                                                collapsed LDA
  prior                  p(θ) = Dir(α)                                      p(θ) = Dir(α)
  likelihood             p(w | z = k, β) = Cat(β_k)                         p(w | θ, β) = Cat(θβ)
  posterior              p(z, θ | w)                                        p(θ | w)
  approximate posterior  q(z, θ | φ, γ) = q(z | φ) q(θ | γ) = Cat(φ)Dir(γ)  q(θ | γ) = Dir(γ)

Proof of the likelihood for collapsed LDA:

  p(w_dn | θ_d, β) = Σ_{z_dn} p(z_dn, w_dn | θ_d, β) = Σ_{z_dn} p(z_dn | θ_d) p(w_dn | z_dn, β)
                   = Σ_k θ_dk β_{k, w_dn} = (θ_d β)_{w_dn}.
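The marginalization in the proof above is easy to check numerically: summing p(z)p(w | z) over the K topics gives exactly the categorical parameter θβ of the collapsed likelihood.

```python
import numpy as np

rng = np.random.default_rng(1)
K, V = 3, 8
theta = rng.dirichlet(np.ones(K))         # topic mixture theta, (K,)
beta = rng.dirichlet(np.ones(V), size=K)  # per-topic word distributions, (K, V)

# p(w | theta, beta) by explicitly summing out the topic assignment z ...
p_w_marginal = sum(theta[k] * beta[k] for k in range(K))
# ... equals the collapsed categorical parameter theta @ beta
p_w_collapsed = theta @ beta

assert np.allclose(p_w_marginal, p_w_collapsed)
assert np.isclose(p_w_collapsed.sum(), 1.0)   # still a valid distribution over words
```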
4 Gaussian VAE with AVI

4.1 Cost Function

Figure 5: SVI (plate diagram: z → x, plate over N; per-data-point variational parameters φ).

Figure 6: VAE (plate diagram: z → x, plate over N; shared variational parameters φ).

In both diagrams:
- circles: random variables
- shaded circles: observed random variables
- unshaded circles: hidden random variables
- N: number of samples
- θ: generative model parameters
- φ: variational approximation parameters

  L(θ, φ; x^(i)) = E_{q_φ(z | x^(i))}{−log q_φ(z | x^(i)) + log p_θ(z, x^(i))}
    = log p_θ(x^(i)) − D_KL(q_φ(z | x^(i)) ‖ p_θ(z | x^(i)))            lower bound
    = −D_KL(q_φ(z | x^(i)) ‖ p_θ(z, x^(i)))                             joint-contrastive
    = E_{q_φ(z | x^(i))}{log p_θ(x^(i) | z)} − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))   prior-contrastive
    = Reconstruction − Regularization.

L(θ, φ; x^(i)) is the variational lower bound on the marginal log likelihood of the data point x^(i). The sum Σ_{i=1}^N L(θ, φ; x^(i)) is the evidence lower bound objective (ELBO).
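For a tiny fully discrete model, the three forms of L above — the expectation form, log p(x) minus the KL to the posterior, and reconstruction minus the KL to the prior — can be checked to agree exactly. All names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
p_z = rng.dirichlet(np.ones(4))              # prior p(z) over 4 latent values
p_x_given_z = rng.dirichlet(np.ones(5), 4)   # likelihood p(x|z), shape (4, 5)
q_z = rng.dirichlet(np.ones(4))              # variational q(z|x) for a fixed observation
x = 2                                        # observed value of x

p_joint = p_z * p_x_given_z[:, x]            # p(z, x) as a vector over z
p_x = p_joint.sum()                          # marginal likelihood p(x)
p_post = p_joint / p_x                       # exact posterior p(z|x)

elbo1 = np.sum(q_z * (np.log(p_joint) - np.log(q_z)))                 # E_q[log p(z,x) - log q(z)]
elbo2 = np.log(p_x) - np.sum(q_z * np.log(q_z / p_post))              # log p(x) - KL(q || posterior)
elbo3 = np.sum(q_z * np.log(p_x_given_z[:, x])) \
        - np.sum(q_z * np.log(q_z / p_z))                             # reconstruction - KL(q || prior)

assert np.isclose(elbo1, elbo2) and np.isclose(elbo2, elbo3)
assert elbo1 <= np.log(p_x) + 1e-12          # the bound holds, with equality iff q = posterior
```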
4.2 Optimization

We want to maximize L(θ, φ; x^(i)) in order to maximize the marginal likelihood log p_θ(x^(1), ..., x^(N)) of the data. Therefore, we need to differentiate and optimize L(θ, φ; x^(i)) with respect to θ and φ.

Differentiating with respect to θ:

  ∇_θ L(θ, φ; x^(i)) = ∇_θ E_{q_φ(z | x^(i))}{−log q_φ(z | x^(i)) + log p_θ(z, x^(i))}
    = E_{q_φ(z | x^(i))}{−∇_θ log q_φ(z | x^(i)) + ∇_θ log p_θ(z, x^(i))}   differentiate inside the expectation
    = E_{q_φ(z | x^(i))}{0 + ∇_θ p_θ(z, x^(i)) / p_θ(z, x^(i))}.

Differentiating with respect to φ:

  ∇_φ L(θ, φ; x^(i)) = ∇_φ E_{q_φ(z | x^(i))}{−log q_φ(z | x^(i)) + log p_θ(z, x^(i))}
    ≠ E_{q_φ(z | x^(i))}{∇_φ (−log q_φ(z | x^(i)) + log p_θ(z, x^(i)))}     can't move ∇_φ inside the expectation,

because the expectation itself is taken over q_φ.

4.2.1 Score Function Estimators

Idea: differentiate the density q_φ(z).

  ∇_φ E_{q_φ(z)}{f(z)} = ∇_φ ∫ f(z) q_φ(z) dz
    = ∫ f(z) ∇_φ q_φ(z) dz                              differentiate the density q
    = ∫ f(z) q_φ(z) ∇_φ log q_φ(z) dz                   REINFORCE trick
    = E_{q_φ(z)}{f(z) ∇_φ log q_φ(z)}                   score function estimator
    ≈ (1/L) Σ_{l=1}^L f(z^(l)) ∇_φ log q_φ(z^(l))       Monte Carlo approximation

This estimator has been used in LDA with MFVI or SVI.

4.2.2 Pathwise Gradient Estimators

4.2.2.1 Explicit Reparameterization Gradients

Idea:
- Differentiate the function f(z).
- Apply a standardization function S_φ(z) = ε ~ q(ε) to remove the dependence of q on φ. S_φ should be continuously differentiable w.r.t. its parameters and argument, and invertible: z = S_φ^{-1}(ε).

  ∇_φ E_{q_φ(z)}{f(z)} = ∇_φ ∫ f(z) q_φ(z) dz
    = ∇_φ ∫ f(S_φ^{-1}(ε)) q_φ(S_φ^{-1}(ε)) [dS_φ^{-1}(ε)/dε] dε        integration by substitution
    = ∇_φ ∫ f(S_φ^{-1}(ε)) q(ε) dε                                       q(ε) = q_φ(S_φ^{-1}(ε)) dS_φ^{-1}(ε)/dε
    = E_{q(ε)}{∇_φ f(S_φ^{-1}(ε))}
    = E_{q(ε)}{∇_z f ∇_φ S_φ^{-1}(ε)}                                    chain rule

4.2.2.2 Implicit Reparameterization Gradients

Idea: apply the standardization function S_φ to z itself and differentiate the identity S_φ(z) = ε totally:

  ∇_φ S_φ(z) + ∇_z S_φ(z) ∇_φ z = 0
  ⟹ ∇_φ z = ∇_φ S_φ^{-1}(ε) = −(∇_z S_φ(z))^{-1} ∇_φ S_φ(z).

Then, continuing from the explicit estimator,

  ∇_φ E_{q_φ(z)}{f(z)} = E_{q(ε)}{∇_z f ∇_φ S_φ^{-1}(ε)}
    = E_{q(ε)}{−∇_z f (∇_z S_φ(z))^{-1} ∇_φ S_φ(z)}
    = E_{q_φ(z)}{−∇_z f (∇_z S_φ(z))^{-1} ∇_φ S_φ(z)},

which requires only S_φ and its derivatives, not the inverse S_φ^{-1}.
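As a concrete check, consider the exponential distribution with rate λ (not used elsewhere in this report; it is chosen purely because everything is available in closed form). Its CDF S_λ(z) = 1 − e^{−λz} is an invertible standardization, z = S_λ^{-1}(ε) = −log(1 − ε)/λ, and both the explicit derivative of S^{-1} and the implicit formula ∇_λ z = −(∂S/∂z)^{-1} ∂S/∂λ reduce to −z/λ:

```python
import numpy as np

rng = np.random.default_rng(3)
lam = 1.7
eps = rng.uniform(size=1000)                 # eps ~ Uniform(0, 1)
z = -np.log1p(-eps) / lam                    # z = S^{-1}(eps): exponential samples

# Explicit: differentiate the inverse standardization w.r.t. lambda.
grad_explicit = np.log1p(-eps) / lam**2      # d/dlam of -log(1 - eps)/lam

# Implicit: S(z; lam) = 1 - exp(-lam z); grad = -(dS/dz)^{-1} dS/dlam.
dS_dlam = z * np.exp(-lam * z)
dS_dz = lam * np.exp(-lam * z)
grad_implicit = -dS_dlam / dS_dz             # simplifies to -z / lam

assert np.allclose(grad_explicit, grad_implicit)
assert np.allclose(grad_implicit, -z / lam)
```

The implicit route never inverted S: it only evaluated the CDF's two partial derivatives at the sampled z, which is exactly what makes the method applicable when S^{-1} has no closed form.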
4.2.2.3 Comparison Between Explicit and Implicit

Explicit reparameterization requires a standardization function with a tractable inverse S_φ^{-1} (as for the Gaussian); implicit reparameterization needs only S_φ and its derivatives, which makes it applicable to distributions such as the Gamma and the Dirichlet.

4.3 Model

Table 2: Comparing LDA vs. collapsed LDA vs. Gaussian VAE.

                         LDA                              collapsed LDA                  VAE
  prior                  p(θ) = Dir(α)                    p(θ) = Dir(α)                  p_θ(z) = N(0, I)
  likelihood             p(w | z = k, β) = Cat(β_k)       p(w | θ, β) = Cat(θβ)          p_θ(x | z) = N(µ, σ²I)
  posterior              p(z, θ | w)                      p(θ | w)                       p_θ(z | x)
  approximate posterior  q(z, θ | φ, γ) = Cat(φ)Dir(γ)    q(θ | γ) = Dir(γ)              q_φ(z | x) = N(µ, σ²I)

  L(θ, φ; x^(i)) = E_{q_φ(z | x^(i))}{log p_θ(x^(i) | z)} − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))

  L̃^B(θ, φ; x^(i)) = (1/L) Σ_{l=1}^L log p_θ(x^(i) | z^(i,l)) − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))
    = (1/L) Σ_{l=1}^L log N(x^(i); µ, σ²I) − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))
    = −(1/L) Σ_{l=1}^L ‖x^(i) − µ‖²/(2σ²) + (1/2) Σ_{j=1}^J (1 + log σ_j² − µ_j² − σ_j²)   discarding some constants
    = −(MSE + Regularization),

where the µ, σ in the likelihood are decoder outputs, the µ_j, σ_j in the KL term are encoder outputs, and J is the latent dimension. Maximizing the bound thus minimizes the reconstruction error (MSE) plus the regularization term.
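The closed-form KL term used above, D_KL(N(µ, σ²I) ‖ N(0, I)) = −½ Σ_j (1 + log σ_j² − µ_j² − σ_j²), can be verified against a Monte Carlo estimate of E_q[log q(z) − log p(z)]; names are illustrative.

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return -0.5 * np.sum(1.0 + np.log(sigma**2) - mu**2 - sigma**2)

rng = np.random.default_rng(4)
mu = np.array([0.3, -1.0])
sigma = np.array([0.8, 1.5])

# Monte Carlo check: E_q[log q(z) - log p(z)] with z = mu + sigma * eps.
eps = rng.standard_normal((200_000, 2))
z = mu + sigma * eps
log_q = np.sum(-0.5 * eps**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi), axis=1)
log_p = np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=1)

assert np.isclose(kl_to_standard_normal(mu, sigma), np.mean(log_q - log_p), atol=0.02)
```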
Figure 7: Gaussian VAE.

4.3.1 Encoder

  (µ^(i), σ^(i)) = MLP_φ(x^(i))

Explicit reparameterization: sample ε^(l) ~ p(ε) = N(0, I) and set

  z^(i,l) = S_φ^{-1}(ε^(l), x^(i)) = µ^(i) + σ^(i) ⊙ ε^(l).

Implicit reparameterization: sample z^(i,l) ~ q_φ(z | x^(i)) = N(z; µ^(i), (σ^(i))²I) directly.

4.3.2 Decoder

  (µ, σ) = MLP_θ(z^(i,l)),    x^(i) ~ p_θ(x | z^(i,l)) = N(x; µ, σ²I)

No reparameterization is needed for θ, since the expectation in the bound is taken w.r.t. q_φ, which does not depend on θ.
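A minimal NumPy sketch of one forward pass of the Gaussian VAE described above, using explicit reparameterization. A single-hidden-layer network stands in for each MLP; the weights are random rather than learned, and all names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(5)
D, H, J = 6, 16, 2                           # data dim, hidden width, latent dim

# Encoder (phi) and decoder (theta) weights; a real model would learn these.
W_enc = rng.standard_normal((D, H)) * 0.1
W_mu = rng.standard_normal((H, J)) * 0.1
W_logvar = rng.standard_normal((H, J)) * 0.1
W_dec = rng.standard_normal((J, H)) * 0.1
W_out = rng.standard_normal((H, D)) * 0.1

def forward(x):
    """One single-sample ELBO estimate L~(theta, phi; x)."""
    h = np.tanh(x @ W_enc)
    mu, logvar = h @ W_mu, h @ W_logvar      # q_phi(z|x) = N(mu, diag(exp(logvar)))
    sigma = np.exp(0.5 * logvar)
    eps = rng.standard_normal(J)
    z = mu + sigma * eps                     # explicit reparameterization z = mu + sigma * eps
    x_mu = np.tanh(z @ W_dec) @ W_out        # decoder mean of p_theta(x|z)
    recon = -0.5 * np.sum((x - x_mu) ** 2)   # log-likelihood up to constants (unit decoder variance)
    kl = -0.5 * np.sum(1 + logvar - mu**2 - sigma**2)
    return recon - kl                        # reconstruction minus regularization

x = rng.standard_normal(D)
elbo = forward(x)
```

Because z is a deterministic, differentiable function of (µ, σ, ε), a framework with automatic differentiation could backpropagate this estimate through the sampling step into the encoder weights.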
5 Collapsed LDA as VAE with AVI

5.1 Model

Table 3: Comparing LDA vs. collapsed LDA vs. Gaussian VAE.

                         LDA                              collapsed LDA                  VAE
  prior                  p(θ) = Dir(α)                    p(θ) = Dir(α)                  p_θ(z) = N(0, I)
  likelihood             p(w | z = k, β) = Cat(β_k)       p(w | θ, β) = Cat(θβ)          p_θ(x | z) = N(µ, σ²I)
  posterior              p(z, θ | w)                      p(θ | w)                       p_θ(z | x)
  approximate posterior  q(z, θ | φ, γ) = Cat(φ)Dir(γ)    q(θ | γ) = Dir(γ)              q_φ(z | x) = N(µ, σ²I)

  L(θ, φ; x^(i)) = E_{q_φ(z | x^(i))}{log p_θ(x^(i) | z)} − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))

  L̃^B(θ, φ; x^(i)) = (1/L) Σ_{l=1}^L log p_θ(x^(i) | z^(i,l)) − D_KL(q_φ(z | x^(i)) ‖ p_θ(z))
    = (1/L) Σ_{l=1}^L log Cat(w^(i); θ^(i,l) β) − D_KL(Dir(θ; γ^(i)) ‖ Dir(θ; α))

Figure 8: LDA VAE.
5.1.1 Encoder

Implicit reparameterization:

  γ^(i) = MLP_φ(w^(i)),    θ^(i,l) ~ q_φ(θ | w^(i)) = Dir(θ; γ^(i)).

5.1.2 Decoder

  w^(i) ~ p(w | θ^(i,l), β) = Cat(w; θ^(i,l) β)

No reparameterization is needed for the decoder parameters (here, β).
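Putting the pieces of Section 5 together, a single-document evaluation of the bound can be sketched as follows. The one-layer encoder, the softplus link, and all names are illustrative; in a real implementation the Dirichlet draw would be made differentiable with implicit reparameterization gradients rather than plain sampling.

```python
import numpy as np
from scipy.special import gammaln, digamma

def dirichlet_kl(gamma, alpha):
    """Closed-form KL( Dir(gamma) || Dir(alpha) )."""
    return (gammaln(gamma.sum()) - gammaln(alpha.sum())
            - np.sum(gammaln(gamma) - gammaln(alpha))
            + np.sum((gamma - alpha) * (digamma(gamma) - digamma(gamma.sum()))))

def elbo_estimate(counts, alpha, beta, enc_W, rng):
    """One-sample bound for one document. counts: (V,) bag-of-words vector."""
    gamma = np.log1p(np.exp(counts @ enc_W))       # softplus keeps gamma > 0
    theta = rng.dirichlet(gamma)                   # theta ~ Dir(gamma); training would use
                                                   # implicit reparameterization here
    recon = np.sum(counts * np.log(theta @ beta))  # sum_n log Cat(w_n; theta @ beta)
    return recon - dirichlet_kl(gamma, alpha)

rng = np.random.default_rng(6)
K, V = 3, 12
alpha = np.full(K, 0.5)
beta = rng.dirichlet(np.ones(V), size=K)           # decoder topics, (K, V)
enc_W = np.abs(rng.standard_normal((V, K))) * 0.1  # stand-in encoder weights
counts = rng.integers(0, 4, size=V).astype(float)
elbo = elbo_estimate(counts, alpha, beta, enc_W, rng)
```

The reconstruction term uses the collapsed likelihood Cat(θβ) from Section 3, so no per-word topic assignment z is ever sampled — which is exactly what makes the model end-to-end trainable as a VAE.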