Message passing and approximate message passing
Arian Maleki
Columbia University
What is the problem?
Given a pdf $\mu(x_1, x_2, \ldots, x_n)$, we are interested in:
- $\arg\max_{x_1, x_2, \ldots, x_n} \mu(x_1, x_2, \ldots, x_n)$
- Calculating the marginal pdf $\mu_i(x_i)$
- Calculating the joint distribution $\mu_S(x_S)$ for $S \subseteq \{1, 2, \ldots, n\}$
These problems are important because:
- Many inference problems take one of these forms
- They have many applications in image processing, communication, machine learning, and signal processing
This talk: marginalizing a distribution
- Most of the tools also apply to the other problems
This talk: marginalization
Problem: given a pdf $\mu(x_1, x_2, \ldots, x_n)$, find $\mu_i(x_i)$.
Simplification: all $x_j$'s are binary random variables.
Is it difficult?
- Yes. No polynomial-time algorithm is known for generic $\mu$
- Complexity of variable elimination: $2^{n-1}$
Solution:
- No good solution for generic $\mu$
- Consider structured $\mu$
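To make the exponential cost concrete, here is a minimal sketch of brute-force marginalization over binary variables. The function name `brute_force_marginal` and the toy distribution `example_mu` are hypothetical, chosen only for illustration; the loop visits all $2^n$ configurations, which is why this approach does not scale.

```python
import itertools

def brute_force_marginal(mu, n, i):
    """Return [P(x_i = 0), P(x_i = 1)] by summing mu over all 2**n binary configurations.

    mu may be unnormalized; the result is normalized at the end.
    """
    marginal = [0.0, 0.0]
    for x in itertools.product((0, 1), repeat=n):
        marginal[x[i]] += mu(x)
    total = marginal[0] + marginal[1]
    return [m / total for m in marginal]

# Hypothetical unnormalized distribution on 3 binary variables (illustration only).
def example_mu(x):
    return 1.0 + x[0] + 2.0 * x[1] * x[2]

result = brute_force_marginal(example_mu, 3, 0)
print(result)
```

Doubling `n` squares the number of terms in the sum, which is the motivation for looking at structured $\mu$.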
Structured µ
Problem: given a pdf $\mu(x_1, x_2, \ldots, x_n)$, find $\mu_i(x_i)$.
Simplification: all $x_j$'s are binary random variables.
Example 1: $x_1, x_2, \ldots, x_n$ are independent
- Easy marginalization
- Not interesting
Example 2: independent subsets
$$\mu(x_1, x_2, \ldots, x_n) = \mu_1(x_1, \ldots, x_k)\, \mu_2(x_{k+1}, \ldots, x_n)$$
- More interesting and less easy
- Complexity is still exponential in the size of the subsets
- Does not cover many interesting cases
Structured µ (cont'd)
Problem: given a pdf $\mu(x_1, x_2, \ldots, x_n)$, find $\mu_i(x_i)$.
Simplification: all $x_j$'s are binary random variables.
Example 3: a more interesting structure
$$\mu(x_1, x_2, \ldots, x_n) = \mu_1(x_{S_1})\, \mu_2(x_{S_2}) \cdots \mu_l(x_{S_l})$$
- Unlike independence, the subsets may overlap: $S_i \cap S_j \neq \emptyset$
- Covers many interesting practical problems
- Let $|S_i| \ll n$ for all $i$. Is marginalization easier than for a generic $\mu$?
  o Not clear!
We introduce the factor graph to address this question.
Factor graph
Slight generalization:
$$\mu(x_1, x_2, \ldots, x_n) = \frac{1}{Z} \prod_{a \in F} \psi_a(x_{S_a})$$
- $\psi_a : \{0, 1\}^{|S_a|} \to \mathbb{R}_+$, not necessarily a pdf
- $Z$: normalizing constant, called the "partition function"
- $\psi_a(x_{S_a})$: factors
Factor graph:
- Variable nodes: $n$ nodes corresponding to $x_1, x_2, \ldots, x_n$
- Function nodes: $|F|$ nodes corresponding to $\psi_a(x_{S_a}), \psi_b(x_{S_b}), \ldots$
- Edge between function node $a$ and variable node $i$ iff $i \in S_a$
[Figure: bipartite graph with variable nodes 1, 2, ..., n on one side and factor nodes a, b, ... on the other]
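Factor graph (cont'd)
$$\mu(x_1, x_2, \ldots, x_n) = \frac{1}{Z} \prod_{a \in F} \psi_a(x_{S_a})$$
Factor graph:
- Variable nodes: $n$ nodes corresponding to $x_1, x_2, \ldots, x_n$
- Function nodes: $|F|$ nodes corresponding to $\psi_a(x_{S_a}), \psi_b(x_{S_b}), \ldots$
- Edge between function node $a$ and variable node $i$ iff $i \in S_a$
Some notation:
- $\partial a$: neighbors of function node $a$: $\partial a = \{i : i \in S_a\}$
- $\partial i$: neighbors of variable node $i$: $\partial i = \{b \in F : i \in S_b\}$
[Figure: bipartite graph with variable nodes 1, 2, ..., n on one side and factor nodes a, b, ... on the other]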
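Factor graph example
$$\mu(x_1, x_2, x_3) = \frac{1}{Z}\, \psi_a(x_1, x_2)\, \psi_b(x_2, x_3)$$
Some notation:
- $\partial a$: neighbors of function node $a$: $\partial a = \{1, 2\}$, $\partial b = \{2, 3\}$
- $\partial i$: neighbors of variable node $i$: $\partial 1 = \{a\}$, $\partial 2 = \{a, b\}$
[Figure: variable nodes 1, 2, 3 and factor nodes a, b; a is connected to 1 and 2, b is connected to 2 and 3]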
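The neighbor sets in this example can be derived mechanically from the factor scopes. The following is a small sketch; the dictionary layout and the helper names `factor_neighbors` and `variable_neighbors` are assumptions made for illustration, not notation from the slides.

```python
# Factor scopes for mu(x1, x2, x3) = (1/Z) psi_a(x1, x2) psi_b(x2, x3).
factor_scopes = {"a": (1, 2), "b": (2, 3)}

def factor_neighbors(scopes):
    """Neighbors of each function node: the variables in its scope."""
    return {a: set(scope) for a, scope in scopes.items()}

def variable_neighbors(scopes):
    """Neighbors of each variable node: the factors whose scope contains it."""
    neighbors = {}
    for a, scope in scopes.items():
        for i in scope:
            neighbors.setdefault(i, set()).add(a)
    return neighbors

da = factor_neighbors(factor_scopes)
di = variable_neighbors(factor_scopes)
print(da)
print(di)
```

Storing only the scopes keeps the graph implicit: an edge exists exactly when a variable index appears in a factor's scope.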
Why factor graph?
Factor graphs make the structure of $\mu$ easier to understand:
- Independence: the graph has disconnected pieces
- Tree graphs: easy marginalization
- Loops (especially short loops): generally make the problem more complicated
[Figure: bipartite graph with variable nodes 1, 2, ..., n and factor nodes a, b, ...]
Marginalization on tree-structured graphs
Tree: graph with no loops.
Marginalization by variable elimination is efficient on such graphs (distributions):
  o Linear in $n$
  o Exponential in the size of the factors
Proof: on the board
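Summary of variable elimination on trees
Notation ($\partial a$ denotes the neighbors of factor $a$, $\partial j$ the neighbors of variable $j$):
$$\mu(x_1, x_2, \ldots, x_n) = \prod_{a \in F} \psi_a(x_{\partial a})$$
- $\mu_{a\to j}$: belief from factor node $a$ to variable node $j$
- $\hat\mu_{j\to a}$: belief from variable node $j$ to factor node $a$
We also have
$$\mu_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\mu_{l\to a}(x_l)$$
$$\hat\mu_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \mu_{b\to j}(x_j)$$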
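Belief propagation algorithm on trees
Based on variable elimination on trees, consider the iterative algorithm:
$$\nu^{t+1}_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
- $\nu^{t}_{a\to j}$ and $\hat\nu^{t}_{j\to a}$ are the beliefs of the nodes at time $t$
- Each node propagates its belief to the other nodes
- Hope: convergence to a good fixed point,
$$\nu^{t}_{b\to j}(x_j) \to \mu_{b\to j}(x_j), \qquad \hat\nu^{t}_{j\to b}(x_j) \to \hat\mu_{j\to b}(x_j) \qquad \text{as } t \to \infty$$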
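Belief propagation algorithm on trees
Belief propagation:
$$\nu^{t+1}_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
Why BP and not variable elimination?
- On trees: no good reason
- On loopy graphs: BP can still be applied easily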
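Belief propagation algorithm on trees
Belief propagation:
$$\nu^{t+1}_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
Lemma: On a tree graph, BP converges to the marginal distributions in a finite number of iterations, i.e.,
$$\nu^{t}_{b\to j}(x_j) \to \mu_{b\to j}(x_j), \qquad \hat\nu^{t}_{j\to b}(x_j) \to \hat\mu_{j\to b}(x_j) \qquad \text{as } t \to \infty$$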
Summary of belief propagation on trees
- Equivalent to variable elimination
- Exact on trees
- Converges in a finite number of iterations: 2 times the maximum depth of the tree
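As a sanity check of exactness on trees, the following sketch runs parallel BP updates on the three-variable chain with factors $\psi_a(x_1, x_2)$ and $\psi_b(x_2, x_3)$ and compares the resulting beliefs with brute-force marginals. The function `run_bp`, the 0-based variable indices, the per-step normalization, and the factor tables are all assumptions made for illustration. Messages `nu` go from factor to variable and `nu_hat` from variable to factor.

```python
import itertools
import numpy as np

def run_bp(factors, n_vars, n_iters):
    """Parallel sum-product on a factor graph with binary variables.

    factors: dict name -> (scope tuple, table), where table[x_scope] > 0.
    Returns normalized beliefs, one length-2 array per variable.
    """
    edges = [(a, j) for a, (scope, _) in factors.items() for j in scope]
    nu = {e: np.ones(2) for e in edges}      # factor a -> variable j, keyed (a, j)
    nu_hat = {e: np.ones(2) for e in edges}  # variable j -> factor a, keyed (a, j)
    for _ in range(n_iters):
        new_nu, new_hat = {}, {}
        for a, j in edges:
            scope, table = factors[a]
            # Factor-to-variable: sum out every variable of the factor except j.
            msg = np.zeros(2)
            for x in itertools.product((0, 1), repeat=len(scope)):
                w = table[x]
                for l, xl in zip(scope, x):
                    if l != j:
                        w = w * nu_hat[(a, l)][xl]
                msg[x[scope.index(j)]] += w
            new_nu[(a, j)] = msg / msg.sum()
            # Variable-to-factor: product of messages from the other factors.
            prod = np.ones(2)
            for b, k in edges:
                if k == j and b != a:
                    prod = prod * nu[(b, j)]
            new_hat[(a, j)] = prod / prod.sum()
        nu, nu_hat = new_nu, new_hat
    beliefs = {}
    for j in range(n_vars):
        belief = np.ones(2)
        for a, k in edges:
            if k == j:
                belief = belief * nu[(a, j)]
        beliefs[j] = belief / belief.sum()
    return beliefs

# Illustrative positive factor tables (not from the slides).
psi_a = np.array([[1.0, 2.0], [3.0, 1.0]])
psi_b = np.array([[2.0, 1.0], [1.0, 4.0]])
factors = {"a": ((0, 1), psi_a), "b": ((1, 2), psi_b)}
beliefs = run_bp(factors, n_vars=3, n_iters=5)

# Brute-force marginals for comparison.
joint = psi_a[:, :, None] * psi_b[None, :, :]
joint = joint / joint.sum()
exact = {0: joint.sum(axis=(1, 2)), 1: joint.sum(axis=(0, 2)), 2: joint.sum(axis=(0, 1))}
for j in range(3):
    print(j, beliefs[j], exact[j])
```

The same `run_bp` routine runs unchanged on a graph with loops, but then the beliefs are only approximations and the iteration may fail to settle.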
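Loopy belief propagation
Apply belief propagation to a loopy graph:
$$\nu^{t+1}_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
Wait until the algorithm converges, then announce $\prod_{a \in \partial j} \nu_{a\to j}(x_j)$ (up to normalization) as $\mu_j(x_j)$.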
Challenges for loopy BP
- Does not necessarily converge
  Example: $\mu(x_1, x_2, x_3) = I(x_1 \neq x_2)\, I(x_2 \neq x_3)\, I(x_3 \neq x_1)$
  o $t = 1$: start with $x_1 = 0$
  o $t = 2$: $x_2 = 1$
  o $t = 3$: $x_3 = 0$
  o $t = 4$: $x_1 = 1$
- Even if it converges, the result is not necessarily the marginal distribution
- Can have more than one fixed point
[Figure: loop through variable nodes 1, 2, 3]
Advantages of loopy BP
- In many applications, works really well:
  o Coding theory
  o Machine vision
  o Compressed sensing
  o ...
- In some cases, loopy BP can be analyzed theoretically
Some theoretical results on loopy BP
- BP equations have at least one fixed point
  o Proof uses Brouwer's theorem
  o BP does not necessarily converge to any of the fixed points
- BP is exact on trees
- BP is accurate for locally tree-like graphs
  o Gallager 1963, Luby et al. 2001, Richardson and Urbanke 2001
- Several methods exist for verifying correctness of BP
  o Computation trees
  o Dobrushin condition
  o ...
Locally tree-like graphs
BP is accurate for locally tree-like graphs.
Heuristic: sparse graphs (without many edges) can be locally tree-like.
How to check?
- The girth of the graph, $\mathrm{girth}(G)$, is of order $\log(n)$ or larger
- $\mathrm{girth}(G)$ is the length of the shortest cycle
- A random $(d_v, d_c)$ graph is locally tree-like with high probability
Dobrushin condition
A sufficient condition for convergence of BP.
Influence of node $j$ on node $i$:
$$C_{ij} = \sup_{x, x'} \left\{ \left\| \mu(x_i \mid x_{V \setminus i}) - \mu(x_i \mid x'_{V \setminus i}) \right\|_{TV} : x_l = x'_l \ \ \forall\, l \neq j \right\}$$
Intuition: if the influence of the nodes on each other is very small, then oscillations should not occur.
Theorem: If $\gamma = \sup_i \left( \sum_j C_{ij} \right) < 1$, then the BP marginals converge to the true marginals.
Summary of BP
Belief propagation:
$$\nu^{t+1}_{a\to j}(x_j) = \sum_{\{x_l :\, l \in \partial a \setminus j\}} \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
- Exact on trees
- Accurate on locally tree-like graphs
- No generic proof technique that works on all problems
- Successful in many applications, even when we cannot prove convergence
More challenges
What if $\mu$ is a pdf on $\mathbb{R}^n$?
Extension seems straightforward:
$$\nu^{t+1}_{a\to j}(x_j) = \int \psi_a(x_{\partial a}) \prod_{l \in \partial a \setminus j} \hat\nu^{t}_{l\to a}(x_l)\, dx_{\partial a \setminus j}$$
$$\hat\nu^{t+1}_{j\to a}(x_j) = \prod_{b \in \partial j \setminus a} \nu^{t}_{b\to j}(x_j)$$
But the calculations become complicated:
- The pdf has to be discretized to a certain accuracy
- Many numerical multiple integrations are required
Exception: Gaussian graphical models
$$\mu(x) = \frac{1}{Z} \exp\left(-x^T J x + c^T x\right)$$
- All messages become Gaussian
- NOT the focus of this talk
Approximate message passing
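Model
$$y = A x_o + w$$
- $x_o$: $k$-sparse vector in $\mathbb{R}^N$
- $A$: $n \times N$ design matrix
- $y$: measurement vector in $\mathbb{R}^n$
- $w$: measurement noise in $\mathbb{R}^n$
[Figure: $y = A x + w$ drawn as blocks, with an $n \times N$ matrix $A$ and a $k$-sparse signal $x$]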
LASSO
Many useful heuristic algorithms:
- Noiseless: $\ell_1$-minimization
$$\text{minimize}_x \ \|x\|_1 \quad \text{subject to} \quad y = Ax, \qquad \|x\|_1 = \sum_i |x_i|$$
- Noisy: LASSO
$$\text{minimize}_x \ \frac{1}{2}\|y - Ax\|_2^2 + \lambda \|x\|_1$$
Both are convex optimization problems.
Chen, Donoho, Saunders (96), Tibshirani (96)
Algorithmic challenges of sparse recovery
Use convex optimization tools to solve LASSO.
Computational complexity:
- Interior point methods:
  o Generic tools
  o Appropriate for $N < 5000$
- Homotopy methods:
  o Use the structure of the LASSO
  o Appropriate for $N < 50000$
- First-order methods:
  o Low computational complexity per iteration
  o Require many iterations
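To illustrate what a first-order method looks like here, the sketch below implements a standard proximal-gradient (iterative soft-thresholding) loop for the LASSO objective. The slides do not specify a particular first-order method; the step size $1/\|A\|_2^2$, the function names `lasso_objective` and `ista`, and the random test problem are assumptions chosen for illustration.

```python
import numpy as np

def lasso_objective(A, y, x, lam):
    """0.5 * ||y - A x||_2^2 + lam * ||x||_1."""
    return 0.5 * np.sum((y - A @ x) ** 2) + lam * np.sum(np.abs(x))

def ista(A, y, lam, n_iters):
    """Proximal gradient for LASSO with step 1 / ||A||_2^2 (an assumed, conservative choice)."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    objective = [lasso_objective(A, y, x, lam)]
    for _ in range(n_iters):
        # Gradient step on the quadratic term.
        u = x + step * (A.T @ (y - A @ x))
        # Soft thresholding: the proximal operator of the l1 penalty.
        x = np.sign(u) * np.maximum(np.abs(u) - step * lam, 0.0)
        objective.append(lasso_objective(A, y, x, lam))
    return x, objective

# Small random problem (illustrative sizes only).
rng = np.random.default_rng(0)
n, N, k = 50, 100, 5
A = rng.normal(size=(n, N)) / np.sqrt(n)
x_o = np.zeros(N)
x_o[:k] = 1.0
y = A @ x_o + 0.01 * rng.normal(size=n)
x_hat, objective = ista(A, y, lam=0.05, n_iters=200)
print(objective[0], objective[-1])
```

Each iteration costs two matrix-vector products, which is the "low complexity per iteration" trade-off: cheap steps, but typically many of them.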
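Message passing for LASSO
Define:
$$\mu(dx) = \frac{1}{Z}\, e^{-\beta\lambda \|x\|_1 - \frac{\beta}{2}\|y - Ax\|_2^2}\, dx$$
- $Z$: normalization constant
- $\beta > 0$
As $\beta \to \infty$: $\mu$ concentrates around the solution of LASSO, $\hat{x}^\lambda$.
Marginalize $\mu$ to obtain $\hat{x}^\lambda$.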
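One way to see the concentration claim is to compare the density at two points. In the sketch below, the name $C_\lambda$ for the LASSO cost is introduced here for illustration and does not appear in the slides.

```latex
\mu(dx) = \frac{1}{Z}\, e^{-\beta\, C_\lambda(x)}\, dx,
\qquad
C_\lambda(x) = \frac{1}{2}\|y - Ax\|_2^2 + \lambda \|x\|_1 .

% Density ratio between any two points x and x':
\frac{\mu(x)}{\mu(x')} = e^{-\beta\,\left(C_\lambda(x) - C_\lambda(x')\right)}
\;\longrightarrow\; 0
\quad \text{as } \beta \to \infty \text{ whenever } C_\lambda(x) > C_\lambda(x') .
```

So as $\beta$ grows, any point with a larger cost becomes exponentially less likely than a minimizer of $C_\lambda$, which is why the mass of $\mu$ collects around the LASSO solution.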
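Message passing algorithm
$$\mu(dx) = \frac{1}{Z}\, e^{-\beta\lambda \sum_i |x_i| - \frac{\beta}{2} \sum_a (y_a - (Ax)_a)^2}\, dx$$
Belief propagation update rules (here $\nu_{i\to a}$ goes from variable $i$ to factor $a$, and $\hat\nu_{a\to i}$ from factor $a$ to variable $i$):
$$\nu^{t+1}_{i\to a}(s_i) = e^{-\beta\lambda |s_i|} \prod_{b \neq a} \hat\nu^{t}_{b\to i}(s_i)$$
$$\hat\nu^{t}_{a\to i}(s_i) = \int e^{-\frac{\beta}{2}(y_a - (As)_a)^2} \prod_{j \neq i} d\nu^{t}_{j\to a}(s_j)$$
Pearl (82)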
Issues of message passing algorithm
$$\nu^{t+1}_{i\to a}(s_i) = e^{-\beta\lambda |s_i|} \prod_{b \neq a} \hat\nu^{t}_{b\to i}(s_i)$$
$$\hat\nu^{t}_{a\to i}(s_i) = \int e^{-\frac{\beta}{2}(y_a - (As)_a)^2} \prod_{j \neq i} d\nu^{t}_{j\to a}(s_j)$$
Challenges:
- Messages are distributions over the real line
- Calculation of messages is difficult
  o Numerical integration
  o The number of variables in the integral is $N - 1$
- The graph is not "tree-like"
  o No good analysis tools
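Challenges for message passing algorithm
$$\nu^{t+1}_{i\to a}(s_i) = e^{-\beta\lambda |s_i|} \prod_{b \neq a} \hat\nu^{t}_{b\to i}(s_i)$$
$$\hat\nu^{t}_{a\to i}(s_i) = \int e^{-\frac{\beta}{2}(y_a - (As)_a)^2} \prod_{j \neq i} d\nu^{t}_{j\to a}(s_j)$$
Challenges:
- Messages are distributions over the real line
- Calculation of messages is difficult
- The graph is not "tree-like"
Solution: blessing of dimensionality
  o As $N \to \infty$: $\hat\nu^{t}_{a\to i}(s_i)$ converges to a Gaussian
  o As $N \to \infty$: $\nu^{t}_{i\to a}(s_i)$ converges to
$$f_\beta(s; x, b) \equiv \frac{1}{z_\beta(x, b)}\, e^{-\beta |s| - \frac{\beta}{2b}(s - x)^2}$$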
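Simplifying messages for large N
Simplification of $\hat\nu^{t}_{a\to i}(s_i)$:
$$\hat\nu^{t}_{a\to i}(s_i) = \int e^{-\frac{\beta}{2}(y_a - (As)_a)^2} \prod_{j \neq i} d\nu^{t}_{j\to a}(s_j)$$
$$= \mathbb{E} \exp\left(-\frac{\beta}{2}\Big(y_a - A_{ai} s_i - \sum_{j \neq i} A_{aj} s_j\Big)^2\right)$$
$$= \mathbb{E} \exp\left(-\frac{\beta}{2}(A_{ai} s_i - Z)^2\right), \qquad Z = y_a - \sum_{j \neq i} A_{aj} s_j$$
$Z \overset{d}{\approx} W$, where $W$ is Gaussian:
$$\left| \mathbb{E}\, h_{s_i}(Z) - \mathbb{E}\, h_{s_i}(W) \right| \le \frac{C_t}{N^{1/2} (\hat\tau^{t}_{a\to i})^{3/2}}$$
where $h_{s_i}(z) = \exp\left(-\frac{\beta}{2}(A_{ai} s_i - z)^2\right)$ is the function inside the expectation above.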
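Message passing
As $N \to \infty$, $\hat\nu^{t}_{a\to i}$ converges to a Gaussian distribution.
Theorem: Let $x^{t}_{j\to a}$ and $\tau^{t}_{j\to a}/\beta$ denote the mean and variance of the distribution $\nu^{t}_{j\to a}$. Then there exists a constant $C_t$ such that
$$\sup_{s_i} \left| \hat\nu^{t}_{a\to i}(s_i) - \hat\phi^{t}_{a\to i}(s_i) \right| \le \frac{C_t}{N (\hat\tau^{t}_{a\to i})^{3}},$$
$$\hat\phi^{t}_{a\to i}(ds_i) \equiv \sqrt{\frac{\beta A_{ai}^2}{2\pi \hat\tau^{t}_{a\to i}}}\; e^{-\beta (A_{ai} s_i - z^{t}_{a\to i})^2 / (2 \hat\tau^{t}_{a\to i})}\, ds_i,$$
where
$$z^{t}_{a\to i} \equiv y_a - \sum_{j \neq i} A_{aj} x^{t}_{j\to a}, \qquad \hat\tau^{t}_{a\to i} \equiv \sum_{j \neq i} A_{aj}^2 \tau^{t}_{j\to a}.$$
Donoho, M., Montanari (11)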
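Shape of $\nu^{t}_{i\to a}$: intuition
Summary:
- $\hat\nu^{t}_{b\to i}(s_i) \approx$ Gaussian
- $\nu^{t+1}_{i\to a}(s_i) = e^{-\beta\lambda |s_i|} \prod_{b \neq a} \hat\nu^{t}_{b\to i}(s_i)$
Shape of $\nu^{t+1}_{i\to a}(s_i)$:
$$\nu^{t+1}_{i\to a}(s_i) \approx \frac{1}{z_\beta(x, b)}\, e^{-\beta |s| - \frac{\beta}{2b}(s - x)^2}$$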
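Message passing (cont'd)
Theorem: Suppose that $\hat\nu^{t}_{a\to i} = \hat\phi^{t}_{a\to i}$, with mean $z^{t}_{a\to i}$ and variance $\hat\tau^{t}_{a\to i} = \hat\tau^{t}$. Then at the next iteration we have
$$\left| \nu^{t+1}_{i\to a}(s_i) - \phi^{t+1}_{i\to a}(s_i) \right| < C/n,$$
$$\phi^{t+1}_{i\to a}(s_i) \equiv \lambda\, f_\beta\Big(\lambda s_i;\ \lambda \sum_{b \neq a} A_{bi} z^{t}_{b\to i},\ \lambda^2 (1 + \hat\tau^{t})\Big),$$
and
$$f_\beta(s; x, b) \equiv \frac{1}{z_\beta(x, b)}\, e^{-\beta |s| - \frac{\beta}{2b}(s - x)^2}.$$
Donoho, M., Montanari (11)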
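Next asymptotic: $\beta \to \infty$
Reminder:
- $x^{t}_{j\to a}$: mean of $\nu^{t}_{j\to a}$
- $z^{t}_{a\to i}$: mean of $\hat\nu^{t}_{a\to i}$
- $z^{t}_{a\to i} \equiv y_a - \sum_{j \neq i} A_{aj} x^{t}_{j\to a}$
Can we write $x^{t+1}_{i\to a}$ in terms of $z^{t}_{a\to i}$?
$$x^{t+1}_{i\to a} = \frac{1}{z_\beta} \int s_i\, e^{-\beta |s_i| - \beta \left(s_i - \sum_{b \neq a} A_{bi} z^{t}_{b\to i}\right)^2}\, ds_i$$
Can be done, but the Laplace method simplifies it ($\beta \to \infty$):
$$x^{t+1}_{i\to a} = \eta\Big(\sum_{b \neq a} A_{bi} z^{t}_{b\to i}\Big),$$
where $\eta$ is the soft-thresholding function.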
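The Laplace-method step can be checked numerically. For the density $f_\beta(s; x, b) \propto e^{-\beta |s| - \frac{\beta}{2b}(s - x)^2}$, the minimizer of $|s| + (s - x)^2/(2b)$ is the soft threshold of $x$ at level $b$, so the mean should approach that value as $\beta$ grows. The sketch below compares the two; the function names, the integration grid, and the test values of $x$, $b$, $\beta$ are assumptions chosen for illustration.

```python
import numpy as np

def soft_threshold(x, theta):
    """eta(x; theta) = sign(x) * max(|x| - theta, 0)."""
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0.0)

def mean_of_f_beta(x, b, beta, half_width=10.0, n_grid=400001):
    """Mean of f_beta(s; x, b), proportional to exp(-beta*|s| - beta*(s-x)^2/(2b)).

    Uses a plain Riemann sum on a uniform grid; subtracting the minimum
    energy before exponentiating avoids underflow for large beta.
    """
    s = np.linspace(-half_width, half_width, n_grid)
    energy = np.abs(s) + (s - x) ** 2 / (2.0 * b)
    weights = np.exp(-beta * (energy - energy.min()))
    return float(np.sum(s * weights) / np.sum(weights))

for x in (1.5, 0.2, -1.5):
    approx = mean_of_f_beta(x, b=0.5, beta=200.0)
    limit = float(soft_threshold(x, 0.5))
    print(x, approx, limit)
```

For inputs well above the threshold the agreement is very close already at moderate $\beta$; near the threshold the mean differs from the limit by a term of order $1/\beta$.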
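Summary
- $x^{t}_{j\to a}$: mean of $\nu^{t}_{j\to a}$
- $z^{t}_{a\to i}$: mean of $\hat\nu^{t}_{a\to i}$
$$z^{t}_{a\to i} = y_a - \sum_{j \neq i} A_{aj} x^{t}_{j\to a}$$
$$x^{t+1}_{i\to a} = \eta\Big(\sum_{b \neq a} A_{bi} z^{t}_{b\to i}\Big)$$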
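Comparison of MP and IST
MP:
$$z^{t}_{a\to i} = y_a - \sum_{j \neq i} A_{aj} x^{t}_{j\to a}, \qquad x^{t+1}_{i\to a} = \eta\Big(\sum_{b \neq a} A_{bi} z^{t}_{b\to i}\Big)$$
IST:
$$z^{t}_{a\to i} = y_a - \sum_{j \neq i} A_{aj} x^{t}_{j\to a}, \qquad x^{t+1}_{i\to a} = \eta\Big(\sum_{b} A_{bi} z^{t}_{b\to i}\Big)$$
[Figure: two bipartite graphs, one per algorithm, connecting nodes $y_1, y_2, \ldots, y_n$ to nodes $x_1, x_2, \ldots, x_n$]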
Approximate message passing
We can further simplify the messages:
$$x^{t+1} = \eta(x^{t} + A^T z^{t}; \lambda_t)$$
$$z^{t} = y - A x^{t} + \frac{\|x^{t}\|_0}{n}\, z^{t-1}$$
IST versus AMP
IST:
$$x^{t+1} = \eta(x^{t} + A^T z^{t}; \lambda_t), \qquad z^{t} = y - A x^{t}$$
AMP:
$$x^{t+1} = \eta(x^{t} + A^T z^{t}; \lambda_t), \qquad z^{t} = y - A x^{t} + \frac{1}{n} \|x^{t}\|_0\, z^{t-1}$$
[Figure: two panels plotting mean square error (0 to 0.05) against iteration (0 to 40), each comparing the "State evolution" prediction with the "Empirical" curve]
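The two recursions differ only in the extra term $\frac{1}{n}\|x^{t}\|_0\, z^{t-1}$, so they can share one implementation. The sketch below is a minimal version under stated assumptions: the threshold rule $\lambda_t = \alpha \|z^{t}\|_2 / \sqrt{n}$, the value $\alpha = 2$, the function name `iterate`, and the random test problem are not from the slides and are chosen only for illustration.

```python
import numpy as np

def soft_threshold(u, theta):
    """eta(u; theta) = sign(u) * max(|u| - theta, 0), applied elementwise."""
    return np.sign(u) * np.maximum(np.abs(u) - theta, 0.0)

def iterate(A, y, n_iters, alpha=2.0, onsager=True):
    """Run the AMP recursion (onsager=True) or the IST recursion (onsager=False).

    The threshold lambda_t = alpha * ||z^t||_2 / sqrt(n) is an assumed tuning rule.
    Returns the list of iterates x^0, x^1, ..., x^{n_iters}.
    """
    n, N = A.shape
    x = np.zeros(N)
    z = y.copy()  # z^0 = y - A x^0 with x^0 = 0
    history = [x.copy()]
    for _ in range(n_iters):
        lam = alpha * np.linalg.norm(z) / np.sqrt(n)
        x = soft_threshold(x + A.T @ z, lam)
        # AMP adds (||x||_0 / n) times the previous residual; IST does not.
        correction = (np.count_nonzero(x) / n) * z if onsager else 0.0
        z = y - A @ x + correction
        history.append(x.copy())
    return history

# Illustrative sparse recovery problem (sizes and noise level are assumptions).
rng = np.random.default_rng(0)
n, N, k = 250, 500, 25
A = rng.normal(size=(n, N)) / np.sqrt(n)
x_o = np.zeros(N)
x_o[rng.choice(N, size=k, replace=False)] = rng.choice([-1.0, 1.0], size=k)
y = A @ x_o + 0.01 * rng.normal(size=n)

amp_iterates = iterate(A, y, n_iters=30, onsager=True)
mse = [float(np.mean((x - x_o) ** 2)) for x in amp_iterates]
print(mse[0], mse[-1])
```

Calling `iterate(..., onsager=False)` runs the IST recursion with the same thresholds, which makes it easy to plot both mean square error curves against the iteration count.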
Theoretical guarantees
- We can predict the performance of AMP at every iteration in the asymptotic regime
- Idea: $x^{t} + A^T z^{t}$ can be modeled as signal plus Gaussian noise
- We keep track of the noise variance
- I will discuss this part in the "Risk Seminar"