Integrated Non-Factorized Variational Inference
Shaobo Han, Xuejun Liao and Lawrence Carin
Duke University
February 27, 2014
Overview
World → Graphical Models → Posterior Inference
[Figure: posterior contours over (θ_1, θ_2) estimated by MCMC (left), mean-field VB (middle), and our method INF-VB (right)]
For full posterior inference, our method is
- a fast deterministic alternative to MCMC
- more accurate than mean-field variational Bayes (VB)
Outline
- Introduction
- Integrated Nested Laplace Approximation (INLA)
- Integrated Non-Factorized Variational Bayes (INF-VB)
- Applications
- Summary
Problem of Interest
Consider a general Bayesian hierarchical model:
  Observation model:  y ~ p(y | x, θ)
  Latent variables:   x ~ p(x | θ)
  Hyperparameters:    θ ~ p(θ)
Posterior inference: p(x, θ | y) → p(x | y), p(θ | y), p(x | θ, y)
The exact joint posterior
  p(x, θ | y) = p(y, x, θ) / p(y) = p(y | x, θ) p(x | θ) p(θ) / ∫∫ p(y | x, θ) p(x | θ) p(θ) dx dθ
can be difficult to evaluate.
Approximate Posterior Inference
Sampling-based methods:
- Markov chain Monte Carlo (MCMC)
Deterministic alternatives:
- Laplace approximation (LA)
- Variational inference
- Expectation propagation (EP)
- Integrated nested Laplace approximation (INLA) [Rue et al., 2009]
Outline
- Introduction
- Integrated Nested Laplace Approximation (INLA)
- Integrated Non-Factorized Variational Bayes (INF-VB)
- Applications
- Summary
INLA in a Nutshell (1/3)
Main idea: discretize the low-dimensional hyperparameter space of θ using a grid G.
1. Laplace approximation [Kass & Steffey, 1989]:
  q_G(x | y, θ_k) = N(x; x*(θ_k), H(x*(θ_k))^{-1}),  θ_k ∈ G
where x*(θ_k) = argmax_x p(x | y, θ_k) is the posterior mode, and H(x*(θ_k)) is the Hessian of the negative log posterior evaluated at the mode.
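As a minimal sketch of this step, the code below fits a Laplace approximation to a hypothetical one-dimensional unnormalized log posterior (standing in for ln p(x | y, θ_k); the specific function is an assumption for illustration only): find the mode numerically, then take the curvature there as the Gaussian precision.

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical unnormalized log posterior ln p(x | y, theta_k) (toy stand-in).
def log_post(x):
    return -0.5 * x**2 + 0.3 * np.sin(2.0 * x)

# Mode x*(theta_k): maximize the log posterior (minimize its negative).
res = minimize(lambda x: -log_post(x[0]), x0=[0.0])
x_star = res.x[0]

# Hessian of the NEGATIVE log posterior at the mode, by central differences;
# its inverse is the variance of the Laplace (Gaussian) fit.
h = 1e-4
hess = -(log_post(x_star + h) - 2.0 * log_post(x_star) + log_post(x_star - h)) / h**2

# Laplace approximation: q_G(x) = N(x; x*, H^{-1})
mu, var = x_star, 1.0 / hess
print(mu, var)
```

In the multivariate case H is the full Hessian matrix and the fit is N(x; x*, H^{-1}); the one-dimensional version above keeps the logic visible.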
INLA in a Nutshell (2/3)
Main idea: discretize the low-dimensional hyperparameter space of θ using a grid G.
By Bayes' rule, for any x,
  p(θ | y) = p(x, y, θ) / [p(y) p(x | y, θ)]    (1)
2. Laplace's method of integration [Tierney & Kadane, 1986]:
  q_LA(θ | y) = p(x, y, θ) / [p(y) q_G(x | y, θ)] |_{x = x*(θ)}    (2)
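A sketch of equation (2) on a toy conjugate model (an assumption for illustration: x ~ N(0, 1/θ), y_i | x ~ N(x, 1), flat prior on θ over a grid). Because the conditional posterior of x is exactly Gaussian here, the ratio p(x, y, θ)/q_G(x | y, θ) at the mode recovers ln p(y | θ) up to a constant, which the test verifies.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(0.5, 1.0, size=20)   # hypothetical data, y_i | x ~ N(x, 1)
n, s = len(y), y.sum()

grid = np.linspace(0.1, 5.0, 50)    # grid G over theta (prior precision of x)
log_qla = np.empty_like(grid)
for k, th in enumerate(grid):
    prec = th + n                    # posterior precision of x given theta
    x_star = s / prec                # posterior mode (= mean, since Gaussian)
    # ln p(x*, y, theta) with p(theta) flat; constants in y alone are dropped
    log_joint = 0.5*np.log(th) - 0.5*th*x_star**2 - 0.5*np.sum((y - x_star)**2)
    # ln q_G(x* | y, theta) = ln N(x*; x*, 1/prec)
    log_qg = 0.5*np.log(prec / (2*np.pi))
    log_qla[k] = log_joint - log_qg  # eq. (2), up to the unknown ln p(y)

w = np.exp(log_qla - log_qla.max())
w /= w.sum()                         # normalized grid weights q_LA(theta_k | y)
```

Normalizing over the grid removes the unknown p(y), which is why equation (2) is usable even though p(y) is intractable in general.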
INLA in a Nutshell (3/3)
Main idea: discretize the low-dimensional hyperparameter space of θ using a grid G; combine q_G(x | y, θ_k) and q_LA(θ | y) into q(x, θ | y).
3. Numerical integration:
  q(x | y) = Σ_k q_G(x | y, θ_k) q_LA(θ_k | y) Δ_k
with area weights Δ_k.
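The numerical-integration step turns the grid into a finite Gaussian mixture. A minimal sketch (the component means, variances, and weights below are hypothetical stand-ins for x*(θ_k), H(x*(θ_k))^{-1}, and the normalized q_LA(θ_k | y) Δ_k):

```python
import numpy as np
from scipy.stats import norm

weights = np.array([0.2, 0.5, 0.3])   # q_LA(theta_k | y) * Delta_k, normalized
means = np.array([0.9, 0.7, 0.4])     # x*(theta_k) for each grid point
sds = np.array([0.40, 0.30, 0.25])    # H(x*(theta_k))^{-1/2}

# Marginal q(x | y) as the weighted mixture of per-grid-point Gaussians.
xs = np.linspace(-2, 3, 1001)
q_x = sum(w * norm.pdf(xs, m, s) for w, m, s in zip(weights, means, sds))

# A mixture of normalized densities with weights summing to 1 integrates to 1.
print(np.trapz(q_x, xs))
```

The same mixture form carries over to INF-VB below, with the Gaussian components and weights replaced by their variationally optimal counterparts.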
INLA in a Nutshell: Benefits and Limitations
Benefits:
1. Preserves full posterior dependencies (i.e., the joint density q(x, θ | y))
2. Computationally efficient (MCMC: hours or days; INLA: seconds or minutes)
Limitations:
1. Applies only to latent Gaussian models (LGMs)
2. No quantification of the accuracy of the approximation q(x, θ | y)
3. The dimension of θ can be no more than 5 or 6
Our method addresses the first two limitations of INLA.
Outline
- Introduction
- Integrated Nested Laplace Approximation (INLA)
- Integrated Non-Factorized Variational Bayes (INF-VB)
- Applications
- Summary
Variational Inference
Variational inference turns Bayesian inference into optimization:
  min_{q(x,θ|y)} KL[q(x, θ | y) ‖ p(x, θ | y)]  s.t.  q(x, θ | y) ∈ Q    (3)
Evidence lower bound (ELBO): applying Jensen's inequality,
  ln p(y) = ln ∫∫ q(x, θ | y) [p(y, x, θ) / q(x, θ | y)] dx dθ
          ≥ ∫∫ q(x, θ | y) ln [p(y, x, θ) / q(x, θ | y)] dx dθ =: L    (4)
The Jensen gap: ln p(y) − L = KL(q(x, θ | y) ‖ p(x, θ | y))
The variational distribution q(x, θ | y) is commonly restricted to tractable families Q.
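The identity ln p(y) − L = KL can be checked directly on a tiny discrete toy (a hypothetical 3×3 joint table standing in for p(y, x, θ) at the observed y), where the sums are exact:

```python
import numpy as np

rng = np.random.default_rng(1)
joint = rng.random((3, 3))            # hypothetical p(y, x, theta) slice at observed y
p_y = joint.sum()                     # evidence p(y)
post = joint / p_y                    # exact posterior p(x, theta | y)

q = rng.random((3, 3))
q /= q.sum()                          # an arbitrary variational distribution q(x, theta | y)

elbo = np.sum(q * (np.log(joint) - np.log(q)))       # L in eq. (4)
kl = np.sum(q * (np.log(q) - np.log(post)))          # KL(q || p(.|y))

print(np.log(p_y), elbo, kl)          # ln p(y) = L + KL, and L <= ln p(y)
```

Minimizing the KL in (3) is therefore equivalent to maximizing the ELBO L, which is the tractable objective since it never requires p(y).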
Mean-Field Variational Bayes (VB)
Assume the factorized form q(x, θ | y) = q(x) q(θ); then
  q*(x, θ | y) = argmin_{q(x,θ|y)} KL(q(x, θ | y) ‖ p(x, θ | y))
             = argmin_{q(x), q(θ)} ∫∫ q(x) q(θ) ln [q(x) q(θ) / p(x, θ | y)] dx dθ
Remarks:
- Easily derived and in closed form for conjugate models
- Challenging for non-conjugate models
- Ignoring posterior dependencies impairs accuracy
- A poor approximation for multi-modal distributions
Our non-factorized variational method addresses these issues with mean-field VB.
Hybrid Continuous-Discrete Family
Consider the non-factorized form
  q(x, θ | y) = q(x | y, θ) q_d(θ | y)    (5)
so that x and θ remain coupled.
1. The continuous approximation q(x | y, θ) is very flexible (e.g., Gaussian, mean-field)
2. The discretized approximation q_d(θ | y) is a finite mixture of Dirac-delta distributions:
  q_d(θ | y) = Σ_k ω_k δ_{θ_k}(θ),  ω_k = q_d(θ_k | y),  Σ_k ω_k = 1    (6)
Proposed Method
Within the proposed hybrid family, the optimal variational distribution is
  q*(x, θ | y) = argmin_{q(x,θ|y)} KL(q(x, θ | y) ‖ p(x, θ | y))
             = argmin_{q(x|y,θ), q_d(θ|y)} ∫∫ q(x | y, θ) q_d(θ | y) ln [q(x | y, θ) q_d(θ | y) / p(x, θ | y)] dx dθ
             = argmin_{q(x|y,θ_k), q_d(θ_k|y)} Σ_k q_d(θ_k | y) ∫ q(x | y, θ_k) ln [q(x | y, θ_k) q_d(θ_k | y) / p(x, θ_k | y)] dx
We give the name integrated non-factorized variational Bayes (INF-VB) to this method.
Computation
Variational optimization algorithm:
Step 1 (Local): for each θ_k ∈ G, independently solve
  q*(x | y, θ_k) = argmin_{q(x|y,θ_k)} KL(q(x | y, θ_k) ‖ p(x | y, θ_k))    (7)
Step 2 (Global): given {q*(x | y, θ_k) : θ_k ∈ G},
  q_d*(θ_k | y) ∝ exp( ∫ q*(x | y, θ_k) ln [p(x, θ_k | y) / q*(x | y, θ_k)] dx )    (8)
- INF-VB is parallelizable, with the dominant computational load distributed across the grid points
- INF-VB requires no iteration between Step 1 and Step 2
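The two steps can be sketched end-to-end on a conjugate toy (an assumption for illustration: x ~ N(0, 1/θ), y_i | x ~ N(x, 1), flat prior on θ over the grid), where Step 1 has a closed form and the KL in (7) is driven to zero, so the exponent in (8) reduces to the local ELBO ln p(y, θ_k) up to a constant:

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(0.3, 1.0, size=15)
n, s = len(y), y.sum()
grid = np.linspace(0.1, 4.0, 40)    # theta grid G

# Step 1 (local): optimal Gaussian q*(x | y, theta_k) for each grid point.
# Conjugacy makes it exact here: KL in eq. (7) is zero.
mu = s / (grid + n)                  # posterior means
var = 1.0 / (grid + n)               # posterior variances

# Step 2 (global): q_d*(theta_k | y) ∝ exp(local ELBO); with zero KL the
# local ELBO equals ln p(y, theta_k) up to an additive constant.
elbo = 0.5*np.log(grid/(grid + n)) + 0.5*s**2/(grid + n)
w = np.exp(elbo - elbo.max())
w /= w.sum()                         # grid weights q_d*(theta_k | y)
print(grid[np.argmax(w)])
```

Both steps are embarrassingly parallel over the grid, and no alternation between them is needed: Step 2 consumes the Step 1 solutions once.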
Our approach unifies INLA under the variational framework.
INLA vs. INF-VB
Main idea: discretize the low-dimensional hyperparameter space of θ using a grid G.
INLA:
1. Gaussian approximation:
  q_G(x | y, θ_k) = N(x; x*(θ_k), H(x*(θ_k))^{-1}),  θ_k ∈ G
2. Hyperparameter learning:
  q_LA(θ | y) ∝ p(x, y, θ) / q_G(x | y, θ) |_{x = x*(θ)}
3. Marginal posterior of x:
  q(x | y) = Σ_k q_G(x | y, θ_k) q_LA(θ_k | y) Δ_k, with area weights Δ_k
INLA vs. INF-VB
Main idea: discretize the low-dimensional hyperparameter space of θ using a grid G.
INF-VB:
Step 1: Variational Gaussian approximation:
  q_VG(x | y, θ_k) = argmin_q KL(q(x | y, θ_k) ‖ p(x | y, θ_k)),  θ_k ∈ G
Step 2: Hyperparameter learning:
  q_d*(θ_k | y) ∝ exp( ∫ q_VG(x | y, θ_k) ln [p(x, θ_k | y) / q_VG(x | y, θ_k)] dx )
Step 3: Marginal posterior of x:
  q(x | y) = ∫ q(x | y, θ) q_d(θ | y) dθ = Σ_k q_VG(x | y, θ_k) q_d(θ_k | y)
Remarks
Benefits:
1. Applicable to more general scenarios (beyond latent Gaussian models)
2. Optimal variational distributions q(x | y, θ_k) and q_d(θ | y)
3. The negative ELBO quantifies the accuracy of the approximation
Limitations:
1. The dimension of θ can be no more than 5 or 6
Application to Bayesian Lasso
1. The ℓ1 norm is non-differentiable at the origin
2. Hence the Laplace approximation of INLA cannot be applied
Bayesian Lasso Regression [Park & Casella, 2008] (1/3)
Model: y = Φx + e, e ~ N(e; 0, σ² I_n), where y ∈ R^n, Φ ∈ R^{n×p}, and e ∈ R^n. We assume
  x_j | σ², λ² ~ (λ / (2σ)) exp(−λ |x_j| / σ)
  σ² ~ InvGa(σ²; a, b)
  λ² ~ Ga(λ²; r, s)
Problem: given y and Φ, find the posterior distributions of x and θ = {λ², σ²}.
Bayesian Lasso Regression (2/3)
Inference options:
1. Data-augmentation Gibbs sampler
2. Mean-field VB
3. INF-VB
INF-VB for the Bayesian Lasso:
(1) q*(x | y, θ_k): constrain q(x | y, θ) = N(x; μ, CCᵀ); then
  KL(q(x | y, θ) ‖ p(x | y, θ)) =: g(μ, C)    (9)
is convex in (μ, C) [Challis & Barber, 2011], with D = CCᵀ.
(2) q*(θ | y): can be evaluated analytically
(3) q*(x | y): a finite mixture of Gaussians
Bayesian Lasso Regression (3/3)
Denote (μ*, D*) = argmin_{μ,D} g(μ, D). The variational Bayesian Lasso solves
  μ* = argmin_μ g(μ),  g(μ) := E_{N(x; μ, D*)} [ ‖y − Φx‖²₂ + 2λσ ‖x‖₁ ]    (10)
so μ* is a counterpart of the Lasso [Tibshirani, 1996] solution
  x̂ = argmin_x f(x),  f(x) = ‖y − Φx‖²₂ + 2λσ ‖x‖₁    (11)
Remarks:
- The conditions of the Lasso hold on average
- The expectation smooths the objective around the origin, making it differentiable
- We optimize a non-differentiable function by operating on a sequence of differentiable functions
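The smoothing mechanism is visible coordinate-wise: taking the expectation of |x_j| under N(μ, d) gives the folded-normal mean, a smooth upper bound on |μ| with zero slope at the origin. A sketch (standard folded-normal formula; the grid and standard deviation are illustrative choices):

```python
import numpy as np
from scipy.stats import norm

# E_{N(mu, sd^2)} |x| = sd*sqrt(2/pi)*exp(-mu^2/(2 sd^2)) + mu*(1 - 2*Phi(-mu/sd)),
# the folded-normal mean: smooth in mu, unlike |mu| which has a kink at 0.
def smoothed_abs(mu, sd):
    return sd*np.sqrt(2/np.pi)*np.exp(-mu**2/(2*sd**2)) + mu*(1 - 2*norm.cdf(-mu/sd))

mus = np.linspace(-2, 2, 401)
g = smoothed_abs(mus, 0.5)

print(g.min())                       # minimum at mu = 0, strictly above 0
print(smoothed_abs(0.0, 0.5))        # = 0.5*sqrt(2/pi): the smoothing gap at 0
```

As d → 0 the bound tightens to |μ|, so the variational objective approaches the Lasso objective; for d > 0 it is everywhere differentiable, which is what makes gradient-based optimization of (10) possible.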
Results (1/4): Diabetes Dataset [Efron et al., 2004]
This benchmark dataset contains:
- measurements on n = 442 diabetes patients
- p = 10 clinical covariates (age, sex, body mass index, average blood pressure, and six blood serum measurements)
- a response variable: a quantitative measure of disease progression
Goal: identify which covariates are important factors.
Methods compared:
- Intensive MCMC runs (ground truth)
- Mean-field VB
- INF-VB-1
- INF-VB-2 (INLA with LA replaced by VG)
- Ordinary least squares (OLS)
Results (2/4): Marginal Posteriors q(x_j | y)
[Figure: marginal posteriors (a) q(x_2 | y) (sex), (b) q(x_4 | y) (bp), (c) q(x_9 | y) (ltg), and (d) q(x_10 | y) (glu), comparing MCMC, INF-VB-1, INF-VB-2, VB, and OLS]
Results (3/4): Marginal Posteriors q(σ² | y) and q(λ² | y)
[Figure: posterior marginals of the hyperparameters, (a) q(σ² | y) and (b) q(λ² | y), comparing MCMC, INF-VB-1, INF-VB-2, VB, and OLS]
- Mean-field VB can severely underestimate the posterior variance
- INF-VB-2 offers a suboptimal solution
Results (4/4): Accuracy and Speed
[Figure: (a) negative ELBO (accuracy) and (b) elapsed time in seconds versus grid size, comparing MCMC, INF-VB-1, INF-VB-2, and VB]
- Grid size m × m, with m = 1, 5, 10, 30, 50
- INF-VB with a 1 × 1 grid: partial Bayesian learning of q(x | y, θ) with a fixed θ
Summary
Our method:
1. Tractable family Q: non-factorized
2. Conditional conjugacy: not required
3. Multimodal posteriors: can handle
4. Parallelizable: yes
More could be done...
Q&A: Accuracy and Speed
[Figure: (a) negative ELBO and (b) elapsed time in seconds versus grid size, comparing MCMC, INF-VB-1 through INF-VB-4, and VB]
- In INF-VB-3 and INF-VB-4 (INLA with LA replaced by VG), we obtain a fast VG solution by minimizing an upper bound on the KL divergence
- Grid size m × m, with m = 1, 5, 10, 30, 50