Efficient Spectral Methods for Learning Mixture Models
Qingqing Huang, Laboratory for Information & Decision Systems, February 2016
Based on joint work with Munther Dahleh, Rong Ge, Sham Kakade, Greg Valiant.
Learning
Data generation: an underlying rule $\theta$ produces data $X(\theta)$.
Learning: infer the underlying rule $\theta$ from the data (estimation, approximation, property testing, optimization of $f(\theta)$).
Challenge: exploit our prior on the structure of the underlying $\theta$ to design fast algorithms that use as few data $X$ as possible to achieve the target accuracy in learning $\theta$.
Two costs to control, both as functions of $\dim(\theta)$: computation complexity and sample complexity.
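Learning Mixture Models
Hidden variable $H \in \{1,\dots,K\}$; observed variables $X = (X_1, X_2, \dots, X_M) \in \mathcal{X}$.
The marginal distribution of the observables is a superposition of simple distributions:
$\Pr(X) = \sum_{k=1}^{K} \Pr(H = k)\,\Pr(X \mid H = k)$, where $\Pr(H = k)$ are the mixing weights and $\Pr(X \mid H = k)$ the conditional probabilities.
Model parameters: $\theta$ = (number of mixture components, mixing weights, conditional probabilities).
Given N i.i.d. samples of the observable variables, estimate the model parameters $\hat\theta$ such that $\|\hat\theta - \theta\| \le \epsilon$.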
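The generative story on this slide can be sketched in a few lines. This is a minimal illustration for a discrete observable, not code from the talk; the function names `sample_mixture` and `marginal` and the toy numbers are assumptions made for the example.

```python
import random

def sample_mixture(weights, conditionals, n_samples, seed=0):
    """Draw (h, x) pairs: first h ~ weights, then x ~ conditionals[h]."""
    rng = random.Random(seed)
    outcomes = list(range(len(conditionals[0])))
    samples = []
    for _ in range(n_samples):
        h = rng.choices(range(len(weights)), weights=weights)[0]
        x = rng.choices(outcomes, weights=conditionals[h])[0]
        samples.append((h, x))
    return samples

def marginal(weights, conditionals):
    """Pr(X = x) = sum_k Pr(H = k) * Pr(X = x | H = k)."""
    n_outcomes = len(conditionals[0])
    return [sum(w * c[x] for w, c in zip(weights, conditionals))
            for x in range(n_outcomes)]

# Toy instance (hypothetical numbers): K = 2 components, 3 outcomes.
weights = [0.3, 0.7]
conditionals = [[0.8, 0.1, 0.1], [0.1, 0.2, 0.7]]

samples = sample_mixture(weights, conditionals, 20000)
empirical = [sum(1 for _, x in samples if x == v) / len(samples) for v in range(3)]
exact = marginal(weights, conditionals)
```

The learner only sees the `x` values; the hidden `h` is discarded, which is what makes recovering `weights` and `conditionals` from `empirical` a nontrivial inverse problem.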
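Examples of Mixture Models
Gaussian Mixtures (GMMs): hidden variable = cluster; observed variables = data points in space.
Topic Models (Bag of Words): hidden variable = topic; observed variables = words in each document.

Examples of Mixture Models
Hidden Markov Models (HMM): hidden variable = current state; observed variables = past, current, and future observations.
Super-Resolution: hidden variable = source; observed variables = complex sinusoids.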
Learning Mixture Models
Hidden variable $H \in \{1,\dots,K\}$; observed variables $X = (X_1, X_2, \dots, X_M) \in \mathcal{X}$.
Given N i.i.d. samples of the observable variables, estimate the model parameters $\hat\theta$.
What do we know about sample complexity and computation complexity?
- There exist exponential lower bounds for sample complexity (worst-case analysis).
- Maximum likelihood estimation is a non-convex optimization problem (EM is a heuristic).

Challenges in Learning Mixture Models
There exist exponential lower bounds for sample complexity (worst-case analysis).
Can we learn non-worst-cases with provably efficient algorithms?
Maximum likelihood estimation is a non-convex optimization problem (EM is a heuristic).
Can we achieve statistical efficiency with tractable computation?

Contribution / Outline of the talk
There exist exponential lower bounds for sample complexity (worst-case analysis).
Can we learn non-worst-cases with provably efficient algorithms?
✓ Spectral algorithms for learning GMMs, HMMs, and super-resolution
✓ Go beyond worst cases with randomness in the analysis and in the algorithm
Maximum likelihood estimation is a non-convex optimization problem (EM is a heuristic).
Can we achieve statistical efficiency with tractable computation?
✓ Denoising a low-rank probability matrix with linear sample complexity (minimax optimal)
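Paradigm of Spectral Algorithms for Learning Mixture Models
Forward direction: a mixture of conditional distributions $\theta$ generates samples $X(\theta)$ (with $N = \infty$); from these we compute statistics $M$ (e.g. moments, marginal probabilities), which form a mixture of rank-one matrices/tensors: $M = A \otimes B \otimes C$.
Inverse direction: a spectral decomposition of $M$ separates the mixtures into $A, B, C$, and easy inverse problems then recover $\theta$.

Paradigm of Spectral Algorithms for Learning
With finitely many samples $X(\theta)$ (N finite): compute the empirical statistics $\hat M$, run the spectral decomposition $\hat A \otimes \hat B \otimes \hat C = \hat M$ (in place of the exact $A \otimes B \otimes C = M$), then solve easy inverse problems to obtain $\hat\theta$.
✓ PCA, spectral clustering, and subspace system ID fit into this paradigm.
✓ Problem-specific algorithm design (what statistics $M$ to use?) and analysis (is the spectral decomposition stable?).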
PART 1: Provably efficient spectral algorithms for learning mixture models
1.1 Learn GMMs: Go beyond worst cases by smoothed analysis
1.2 Learn HMMs: Go beyond worst cases by generic analysis
1.3 Super-resolution: Go beyond worst cases by randomized algorithm
1.1 Learn GMMs: Setup
Hidden variable: cluster $i \in [k]$. Observed variable: data point $x \in \mathbb{R}^n$.
The data points in n-dimensional space follow a mixture of k multivariate Gaussians.
Model parameters: weights $w_i$, means $\mu^{(i)}$, covariance matrices $\Sigma^{(i)}$.
$x \sim \mathcal{N}(\mu^{(i)}, \Sigma^{(i)})$, where component $i$ is drawn with probability $w_i$.

1.1 Learn GMMs: Prior Works
General case:
- Moment matching method [Moitra&Valiant] [Belkin&Sinha]: $\mathrm{Poly}(n, e^{O(k)^k})$.
With restrictive assumptions on model parameters:
- Mean vectors are well separated: pairwise clustering [Dasgupta] [Vempala&Wang], $\mathrm{Poly}(n, k)$.
- Mean vectors of spherical Gaussians are linearly independent: moment tensor decomposition [Hsu&Kakade], $\mathrm{Poly}(n, k)$.
1.1 Learn GMMs: Worst case lower bound
Can we learn every GMM instance to target accuracy in poly runtime and using poly samples?
No. Exponential dependence on k for worst cases. [Moitra&Valiant]
Can we learn most GMM instances with a poly algorithm, without restrictive assumptions on model parameters?
Yes. Worst cases are not everywhere, and we can handle non-worst-cases.

1.1 Learn GMMs: Smoothed Analysis Framework
Escape from the worst cases:
- For an arbitrary instance $\theta$, nature perturbs the parameters with a small amount ($\rho$) of noise, giving $\tilde\theta$.
- Observe data generated by $\tilde\theta$; design and analyze the algorithm for $\tilde\theta$.
Hope: with high probability over nature's perturbation, any arbitrary instance escapes from the degenerate cases and becomes well conditioned.
✓ Bridges worst-case and average-case algorithm analysis [Spielman&Teng]
✓ A stronger notion than generic analysis
Our goal: given samples from the perturbed GMM, learn the perturbed parameters with negligible failure probability over nature's perturbation.

1.1 Learn GMMs: Main Results
Our algorithm learns the GMM parameters up to accuracy ε:
- with fully polynomial time and sample complexity $\mathrm{Poly}(n, k, 1/\epsilon)$;
- assumption: data in high enough dimension, $n = \Omega(k^2)$;
- under smoothed analysis: works with negligible failure probability.
Learning Mixture of Gaussians in High Dimensions, R. Ge, Q. Huang, S. Kakade (STOC 2015)
1.1 Learn GMMs: Algorithmic Ideas
Method of moments: match 4th- and 6th-order moments, $M_4 = \mathbb{E}[x^{\otimes 4}]$ and $M_6 = \mathbb{E}[x^{\otimes 6}]$.
Key challenge: moment tensors are not of low rank, but they have special structure.
Desired low-rank tensors: $X_4 = \sum_{i=1}^{k} \Sigma^{(i)} \otimes \Sigma^{(i)}$ and $X_6 = \sum_{i=1}^{k} \Sigma^{(i)} \otimes \Sigma^{(i)} \otimes \Sigma^{(i)}$.
Structured linear projection: $M_4 = F_4(X_4)$, $M_6 = F_6(X_6)$.
Moment tensors are structured linear projections of the desired low-rank tensors; a delicate algorithm inverts the structured linear projections.

1.1 Learn GMMs: Algorithmic Ideas
Method of moments: match 4th- and 6th-order moments, $M_4 = \mathbb{E}[x^{\otimes 4}]$ and $M_6 = \mathbb{E}[x^{\otimes 6}]$.
Key challenge: moment tensors are not of low rank, but they have special structure.
Why do high dimension n and smoothed analysis help us to learn?
- We have many moment matching constraints with only low-order moments: the number of free parameters (order $k n^2$) is smaller than the number of 6th-order moments (order $n^6$).
- The randomness in nature's perturbation makes matrices/tensors well conditioned: for a Gaussian matrix $X \in \mathbb{R}^{n \times m}$, with probability at least $1 - O(\epsilon^{n})$, $\sigma_m(X) \ge \epsilon\sqrt{n}$.
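The conditioning claim can be checked numerically on a toy case. The sketch below is an illustration only and is not the GMM algorithm: it takes a deliberately degenerate matrix (all columns equal), applies a Gaussian perturbation of scale `rho`, and compares the smallest singular value before and after. The sizes `n`, `m` and the value of `rho` are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, rho = 60, 10, 0.1  # illustrative sizes and perturbation scale

# Degenerate instance: all m columns are identical, so the matrix has rank 1.
A = np.tile(rng.standard_normal((n, 1)), (1, m))
sigma_min_before = np.linalg.svd(A, compute_uv=False)[-1]

# Nature's perturbation: independent Gaussian noise of scale rho on every entry.
A_tilde = A + rho * rng.standard_normal((n, m))
sigma_min_after = np.linalg.svd(A_tilde, compute_uv=False)[-1]

print(f"sigma_min before: {sigma_min_before:.2e}, after: {sigma_min_after:.2e}")
```

The perturbed matrix is bounded away from singularity by an amount that scales with `rho`, which is the property the smoothed analysis relies on.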
1.2 Learn HMMs: Setup
Hidden states $h_{t-n}, \dots, h_{t-1}, h_t, h_{t+1}, \dots, h_{t+n} \in [k]$.
Observations $x_{t-n}, \dots, x_{t-1}, x_t, x_{t+1}, \dots, x_{t+n} \in [d]$.
Window size: $N = 2n + 1$.
Transition probability matrix $Q \in \mathbb{R}^{k \times k}$; observation probabilities $O \in \mathbb{R}^{d \times k}$.
Given observation sequences over a window of length N, how can we recover $\theta = (Q, O)$?
Our focus: how large does the window size N need to be?
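To make the setup concrete, here is a minimal sampler for one window of an HMM. It is a sketch under stated assumptions: `Q` and `O` are taken to be column-stochastic (columns index the conditioning state, matching the shapes $k \times k$ and $d \times k$ above), the initial distribution `pi` is a parameter introduced for the example, and the function name `sample_hmm_window` is hypothetical.

```python
import random

def sample_hmm_window(Q, O, pi, n, seed=0):
    """Sample hidden states and observations over a window of N = 2n + 1 steps.

    Assumed conventions: Q[j][i] = Pr(h_{t+1} = j | h_t = i),
    O[x][i] = Pr(x_t = x | h_t = i), pi[i] = Pr(first state = i).
    """
    rng = random.Random(seed)
    k, d = len(pi), len(O)
    states, observations = [], []
    h = rng.choices(range(k), weights=pi)[0]
    for _ in range(2 * n + 1):
        states.append(h)
        # Emit an observation from column h of O.
        observations.append(rng.choices(range(d), weights=[O[x][h] for x in range(d)])[0])
        # Move to the next hidden state using column h of Q.
        h = rng.choices(range(k), weights=[Q[j][h] for j in range(k)])[0]
    return states, observations

# Toy instance: k = 2 states, d = 2 outputs, noiseless emissions.
Q = [[0.9, 0.2], [0.1, 0.8]]
O = [[1.0, 0.0], [0.0, 1.0]]
states, observations = sample_hmm_window(Q, O, pi=[0.5, 0.5], n=3)
```

With the identity emission matrix used here the observations reveal the states exactly; a learner in the general case only gets `observations`.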
1.2 Learn HMMs: Hardness Results and Our Result
Hidden states in $[k]$, observations in $[d]$, window size $N = 2n + 1$.
Hardness results:
- HMMs are not efficiently PAC learnable under the noisy parity assumption [Abe,Warmuth] [Kearns].
- Construct an instance by reduction to parity with noise: the required window size is $N = \Omega(k)$, and the algorithm complexity is $\Omega(d^k)$.
Our result: excluding a measure-0 set in the parameter space of $\theta = (Q, O)$, for almost all HMM instances the required window size is $N = \Theta(\log_d k)$.
A spectral algorithm achieves sample complexity and runtime that are both $\mathrm{poly}(d, k)$.
Minimal Realization Problems for Hidden Markov Models, Q. Huang, R. Ge, S. Kakade, M. Dahleh (IEEE Transactions on Signal Processing, 2016)

1.2 Learn HMMs: Algorithmic idea
$M = \Pr\big((x_{-n}, \dots, x_{-1}),\, x_0,\, (x_1, \dots, x_n)\big) \in \mathbb{R}^{d^n \times d^n \times d}$.
1. $M$ is a low-rank tensor of rank k: $M = A \otimes B \otimes C$, with
   $A = \Pr(x_1, x_2, \dots, x_n \mid h_0)$, $B = \Pr(x_{-1}, x_{-2}, \dots, x_{-n} \mid h_0)$, $C = \Pr(x_0, h_0)$.
2. Extract $Q, O$ from the tensor factors $A$, $B$:
   $A = (O \odot (O \odot (O \odot \cdots (O \odot OQ) \cdots )Q)Q)Q$ (n nested factors),
   $B = (O \odot (O \odot (O \odot \cdots (O \odot O\tilde Q) \cdots )\tilde Q)\tilde Q)\tilde Q$ (n nested factors),
   $C = O\,\mathrm{diag}(\pi)$,
   where $\odot$ is the column-wise Kronecker (Khatri-Rao) product and $\tilde Q$ is the backward transition matrix.
Key challenge: how large does the window size need to be so that the tensor decomposition is unique?
Our careful generic analysis: if $N = \Theta(\log_d k)$, the worst cases all lie in a measure-0 set.
1.3 Super-Resolution: Setup
[Figure: point sources convolved (*) with a band-limited kernel give the blurred observation $X(\theta)$]
Take Fourier measurements of the band-limited observation.
How can we recover the point sources from coarse measurements of the signal?
✓ Small number of Fourier measurements
✓ Low cutoff frequency
1.3 Super-Resolution: Problem Formulation
✓ Recover point sources (a mixture of k vectors in d-dimensional space): $x(t) = \sum_{j=1}^{k} w_j \delta_{\mu^{(j)}}$.
  Assume minimum separation $\Delta = \min_{j \ne j'} \|\mu^{(j)} - \mu^{(j')}\|_2$.
✓ Use band-limited and noisy Fourier measurements: $\tilde f(s) = \sum_{j=1}^{k} w_j e^{i\pi\langle \mu^{(j)}, s\rangle} + z(s)$, with $\|s\|_\infty \le$ cutoff freq and bounded noise $|z(s)| \le \epsilon_z$ for all $s$.
✓ Achieve target accuracy $\|\hat\mu^{(j)} - \mu^{(j)}\|_2 \le \epsilon$ for all $j \in [k]$.
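The measurement model can be written down directly. The sketch below is illustrative, not the recovery algorithm: `fourier_measurements` and `min_separation` are hypothetical helper names, the factor of π in the exponent is an assumed normalization, and the noise is drawn uniformly from a disc of radius `noise_bound` purely as one way to satisfy the boundedness condition.

```python
import numpy as np

def min_separation(mu):
    """Delta = min over j != j' of the Euclidean distance between sources."""
    diffs = mu[:, None, :] - mu[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    return dists[~np.eye(len(mu), dtype=bool)].min()

def fourier_measurements(mu, w, S, noise_bound=0.0, rng=None):
    """f(s) = sum_j w_j * exp(i*pi*<mu_j, s>) + z(s), one value per row s of S."""
    phases = np.exp(1j * np.pi * (S @ mu.T))  # shape (m, k)
    f = phases @ w
    if noise_bound > 0:
        rng = np.random.default_rng() if rng is None else rng
        angle = rng.uniform(0.0, 2.0 * np.pi, size=f.shape)
        radius = noise_bound * rng.uniform(0.0, 1.0, size=f.shape)
        f = f + radius * np.exp(1j * angle)  # |z(s)| <= noise_bound
    return f

# Toy instance: k = 3 point sources in d = 2 dimensions.
mu = np.array([[0.1, 0.2], [0.5, -0.3], [-0.4, 0.6]])
w = np.array([1.0, 2.0, 0.5])
S = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, -1.0]])

delta = min_separation(mu)
f_clean = fourier_measurements(mu, w, S)
f_noisy = fourier_measurements(mu, w, S, noise_bound=0.01, rng=np.random.default_rng(0))
```

At the zero frequency the clean measurement is just the sum of the weights, which gives a quick sanity check on the sign and scaling conventions.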
1.3 Super-Resolution: Prior Works
Model: $\tilde f(s) = \sum_{j=1}^{k} w_j e^{i\pi\langle \mu^{(j)}, s\rangle} + z(s)$, with $\Delta = \min_{j \ne j'} \|\mu^{(j)} - \mu^{(j')}\|_2$.
1-dimensional $\mu^{(j)}$:
- Take uniform measurements on the grid $s \in \{-N, \dots, -1, 0, 1, \dots, N\}$.
- SDP algorithm with cutoff frequency $N = \Omega(1/\Delta)$ [Candes, Fernandez-Granda].
- Hardness result: $N > C/\Delta$ [Moitra].
- One can use $k \log(k)$ random measurements out of the $2N$ measurements to recover [Tang, Bhaskar, Shah, Recht].
d-dimensional $\mu^{(j)}$:
- Multi-dimensional grid $s \in \{-N, \dots, -1, 0, 1, \dots, N\}^d$.
- Algorithm complexity $O\big(\mathrm{poly}(k, 1/\Delta)^d\big)$.

1.3 Super-Resolution: Main Result
Our algorithm achieves stable recovery:
- using $O((k + d)^2)$ Fourier measurements;
- with the cutoff frequency of the measurements bounded by $O(1/\Delta)$;
- with algorithm runtime $O((k + d)^3)$;
- with negligible failure probability.
Method | cutoff freq | measurements | runtime
SDP | $C_d/\Delta_\infty$ | $(1/\Delta_\infty)^d$ | $\mathrm{poly}((1/\Delta_\infty)^d, k)$
MP | - | - | -
Ours | $\log(kd)/\Delta$ | $(k\log(k) + d)^2$ | $(k\log(k) + d)^2$
Super-Resolution off the Grid, Q. Huang, S. Kakade (NIPS 2015)
1.3 Super-Resolution: Algorithmic Idea
Measurements: $\tilde f(s) = \sum_{j=1}^{k} w_j e^{i\pi\langle \mu^{(j)}, s\rangle} + z(s)$.
✓ Tensor decomposition with measurements on random frequencies.
✓ Choose random samples $S$ such that $F$ admits a particular low-rank decomposition: $F = V_{S'} \otimes V_{S'} \otimes (V_2 D_w)$ (a rank-k 3-way tensor), where
  $V_S = \begin{bmatrix} e^{i\pi\langle \mu^{(1)}, s^{(1)}\rangle} & \cdots & e^{i\pi\langle \mu^{(k)}, s^{(1)}\rangle} \\ e^{i\pi\langle \mu^{(1)}, s^{(2)}\rangle} & \cdots & e^{i\pi\langle \mu^{(k)}, s^{(2)}\rangle} \\ \vdots & \ddots & \vdots \\ e^{i\pi\langle \mu^{(1)}, s^{(m)}\rangle} & \cdots & e^{i\pi\langle \mu^{(k)}, s^{(m)}\rangle} \end{bmatrix}$ (a Vandermonde matrix with complex nodes).
✓ Skip the intermediate step of recovering the order-$N^d$ grid measurements.
✓ Prony's method (Matrix-Pencil / MUSIC / ...) is just choosing measurements on the hyper-grid deterministically.

1.3 Super-Resolution: Algorithmic Idea
- Tensor decomposition with measurements on random frequencies.
- Why do we not contradict the hardness result? We use $O((k + d)^2)$ measurements versus $O\big(\mathrm{poly}(k, 1/\Delta)^d\big)$ on the grid.
✓ If we design a fixed grid of $s$ on which to take measurements $f(s)$, there always exist model instances for which that particular grid fails.
✓ Let the locations of $s$ be random (with structure for the tensor decomposition); then for any model instance, the algorithm works with high probability.
PART 2: Can we achieve optimal sample complexity in a tractable way?
Estimate low-rank probability matrices with linear sample complexity.
This problem is at the core of many spectral algorithms.
We capitalize on insights from community detection to solve it.
2. Low rank matrix: Setup
Probability matrix $B \in \mathbb{R}_+^{M \times M}$ (a distribution over $M^2$ outcomes); $B$ is of rank 2: $B = pp^\top + qq^\top$.
N i.i.d. samples give the count matrix $B_N = \mathrm{Poisson}(NB)$ (frequency counts over the $M^2$ outcomes).
[Figure: example with M = 5 and N = 20, showing $B$, its factors $p$ and $q$, and a sampled count matrix $B_N$]
Goal: find a rank-2 $\hat B$ such that $\|\hat B - B\|_1 \le \epsilon$.
Sample complexity N: upper bound (algorithm) and lower bound.
2. Low rank matrix: Connection to mixture models
Topic model: $\Pr(\text{word}_1, \text{word}_2 \mid \text{topic} = T_1) = pp^\top$ and $\Pr(\text{word}_1, \text{word}_2 \mid \text{topic} = T_2) = qq^\top$; $B$ is the joint distribution over word pairs.
HMM: $\Pr(\text{output}_1, \text{output}_2 \mid \text{state} = S_i) = O_i (OQ_i)^\top$; $B$ is the distribution of consecutive outputs.
Pipeline: N data samples → empirical counts $B_N$ → find a low-rank $\hat B$ close to $B$ → extract parameter estimates.
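Both constructions can be verified numerically on a toy instance. In this sketch the mixing weights `w`, the stationary distribution `pi`, and all numbers are assumptions introduced for the example; the weights are absorbed into `p` and `q` so that the topic-model matrix takes the form $pp^\top + qq^\top$.

```python
import numpy as np

# Topic model: two topics with word distributions u and v over d = 4 words.
w = np.array([0.4, 0.6])                  # assumed mixing weights
u = np.array([0.5, 0.3, 0.1, 0.1])
v = np.array([0.1, 0.1, 0.3, 0.5])
p, q = np.sqrt(w[0]) * u, np.sqrt(w[1]) * v
B_topic = np.outer(p, p) + np.outer(q, q)  # joint distribution over word pairs

# HMM with k = 2 states; columns index the conditioning state.
Q = np.array([[0.9, 0.2],
              [0.1, 0.8]])                # Q[j, i] = Pr(next = j | current = i)
O = np.column_stack([u, v])               # O[x, i] = Pr(output = x | state = i)
pi = np.array([2 / 3, 1 / 3])             # stationary distribution: Q @ pi == pi

# B[x1, x2] = sum_i pi_i * O[x1, i] * (O Q)[x2, i]: consecutive outputs.
B_hmm = O @ np.diag(pi) @ (O @ Q).T

rank_topic = np.linalg.matrix_rank(B_topic, tol=1e-10)
rank_hmm = np.linalg.matrix_rank(B_hmm, tol=1e-10)
```

Both matrices are valid probability matrices of rank 2 even though they have 16 entries, which is the structure the estimator in this part exploits.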
2. Low rank matrix: Sub-Optimal Attempt
Probability matrix $B \in \mathbb{R}_+^{M \times M}$; $B$ is of rank 2: $B = pp^\top + qq^\top$.
N i.i.d. samples give the count matrix $B_N = \mathrm{Poisson}(NB)$.
MLE is a non-convex optimization problem; let's try something spectral.

2. Low rank matrix: Sub-Optimal Attempt
Probability matrix $B \in \mathbb{R}_+^{M \times M}$; $B$ is of rank 2: $B = pp^\top + qq^\top$.
N i.i.d. samples give the count matrix $B_N = \mathrm{Poisson}(NB)$, and $\frac{1}{N} B_N = \frac{1}{N}\mathrm{Poisson}(NB) \to B$ as $N \to \infty$.
Set $\hat B$ to be the rank-2 truncated SVD of $\frac{1}{N} B_N$.
To achieve accuracy $\|\hat B - B\|_1 \le \epsilon$, this needs $N = \Omega(M^2 \log M)$: not sample efficient. Hopefully $N = O(M)$ suffices.
Small data in practice: word distributions in language have fat tails, and the more sample documents N we have, the larger the vocabulary size M.
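The baseline on this slide fits in a few lines of NumPy. This is a sketch of the plain truncated-SVD attempt only, on a randomly generated rank-2 instance; the function names, the instance, and the two sample sizes are illustrative assumptions.

```python
import numpy as np

def rank2_truncated_svd(counts, n_samples):
    """Rank-2 truncated SVD of the normalized count matrix counts / n_samples."""
    U, s, Vt = np.linalg.svd(counts / n_samples, full_matrices=False)
    return (U[:, :2] * s[:2]) @ Vt[:2, :]

rng = np.random.default_rng(0)
M = 30
p, q = rng.random(M), rng.random(M)
# Scale so that B = p p^T + q q^T sums to one: (sum p)^2 + (sum q)^2 = 1.
p *= np.sqrt(0.5) / p.sum()
q *= np.sqrt(0.5) / q.sum()
B = np.outer(p, p) + np.outer(q, q)

def l1_error(n_samples):
    counts = rng.poisson(n_samples * B)          # B_N = Poisson(N B)
    B_hat = rank2_truncated_svd(counts, n_samples)
    return np.abs(B_hat - B).sum()               # entrywise L1 error

err_small = l1_error(300)        # N well below M^2 = 900
err_large = l1_error(300_000)    # N far above M^2
print(f"L1 error: N=300 -> {err_small:.3f}, N=300000 -> {err_large:.3f}")
```

Comparing the two errors shows the regime the slide complains about: the estimate is only accurate once N is large relative to $M^2$.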
2. Low rank matrix: Main Result
Our upper bound algorithm:
- rank-2 estimate $\hat B$ with accuracy $\|\hat B - B\|_1 \le \epsilon$, for all $\epsilon > 0$;
- using $N = O(M/\epsilon^2)$ sample counts;
- runtime $O(M^3)$;
- leads to improved spectral algorithms for learning.
We prove a (strong) lower bound:
- a sequence of $\Omega(M)$ observations is needed to test whether the sequence is i.i.d. uniform over $[M]$ or generated by a 2-state HMM.
- Testing the property is no easier than estimating?
Recovering Structured Probability Matrices, Q. Huang, S. Kakade, W. Kong, G. Valiant (submitted to STOC 2016)
2. Low rank matrix: Algorithmic Idea
We capitalize on ideas from community detection in sparse random networks, which is a special case of our problem formulation.
M nodes, 2 communities. Expected connection matrix $B = pp^\top + qq^\top$; adjacency matrix $B_N = \mathrm{Bernoulli}(NB)$.
Example with M = 6: $p = (.30, .30, .30, .03, .03, .03)$, $q = (.03, .03, .03, .30, .30, .30)$, so that
B =
.09 .09 .09 .02 .02 .02
.09 .09 .09 .02 .02 .02
.09 .09 .09 .02 .02 .02
.02 .02 .02 .09 .09 .09
.02 .02 .02 .09 .09 .09
.02 .02 .02 .09 .09 .09
A generated adjacency matrix, from which $B$ is to be estimated:
1 1 0 0 1 0
1 1 1 0 1 1
0 1 1 0 1 0
0 0 0 0 1 1
1 1 1 1 1 1
0 1 0 0 1 1

2. Low rank matrix: Algorithmic Idea
We capitalize on ideas from community detection in sparse random networks, which is a special case of our problem formulation.
M nodes, 2 communities. Expected connection matrix $B = pp^\top + qq^\top$; adjacency matrix $B_N = \mathrm{Bernoulli}(NB)$.
Regularized truncated SVD [Le, Levina, Vershynin]: remove heavy rows/columns from $B_N$, then run rank-2 SVD on the remaining graph.
[Figure: 6-node, 2-community example showing $B$, its factors $p$ and $q$, and a generated adjacency matrix used for estimation]

2. Low rank matrix: Algorithmic Idea
We capitalize on ideas from community detection in sparse random networks, which is a special case of our problem formulation.
Community detection: M nodes, 2 communities; expected connection matrix $B = pp^\top + qq^\top$; adjacency matrix $B_N = \mathrm{Bernoulli}(NB)$.
Regularized truncated SVD [Le, Levina, Vershynin]: remove heavy rows/columns from $B_N$, then run rank-2 SVD on the remaining graph.
Our problem: $M \times M$ probability matrix $B = pp^\top + qq^\top$; sample counts $B_N = \mathrm{Poisson}(NB)$.
Key challenge: in our general setup, we do not have homogeneous marginal probabilities.
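The regularization step can be sketched as follows. This is a simplified illustration and not the procedure from [Le, Levina, Vershynin]: the rule "a row or column is heavy if its sum exceeds `heavy_factor` times the average row sum" and the name `regularized_rank2_svd` are assumptions chosen for the example.

```python
import numpy as np

def regularized_rank2_svd(counts, heavy_factor=2.0):
    """Zero out heavy rows/columns, then take a rank-2 truncated SVD.

    The threshold (heavy_factor times the average row sum) is an
    illustrative choice, not a tuned or published constant.
    """
    counts = np.asarray(counts, dtype=float)
    avg = counts.sum() / counts.shape[0]
    keep_rows = counts.sum(axis=1) <= heavy_factor * avg
    keep_cols = counts.sum(axis=0) <= heavy_factor * avg
    trimmed = counts * np.outer(keep_rows, keep_cols)
    U, s, Vt = np.linalg.svd(trimmed, full_matrices=False)
    return (U[:, :2] * s[:2]) @ Vt[:2, :], keep_rows, keep_cols

# Toy count matrix: uniform counts except for one abnormally heavy row.
counts = np.ones((6, 6))
counts[0, :] = 100.0
approx, keep_rows, keep_cols = regularized_rank2_svd(counts)
```

In the toy input the heavy first row would otherwise dominate the top singular vectors; after trimming, the SVD reflects the structure of the remaining rows.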
2. Low rank matrix: Algorithmic Idea 1, Binning
We group words according to their empirical marginal probabilities, divide the matrix into blocks, then apply regularized t-SVD to each block.
Phase 1:
1. Estimate the non-uniform marginal $\hat\rho$.
2. Bin the M words according to $\hat\rho_i$.
3. Run regularized t-SVD on each bin-by-bin block of $B_N$ to obtain an estimate.
Key challenges:
✓ Binning is not exact; we need to deal with spillover.
✓ We need to piece together the estimates over bins.

2. Low rank matrix: Algorithmic Idea 2, Refinement
The coarse estimate from Phase 1 gives some global information; we use it to do local refinement for each row/column.
Phase 2:
1. Phase 1 gives a coarse estimate for many words.
2. Refine the estimate for each word using linear regression.
3. Achieve sample complexity $N = O(M/\epsilon^2)$.
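Steps 1 and 2 of Phase 1 can be sketched as below. The geometric binning rule (bin index $\lfloor -\log_{\text{base}} \hat\rho_i \rfloor$), the separate bin for unseen words, and the helper names are assumptions made for illustration; the actual bin boundaries used by the algorithm are not specified on the slide.

```python
import math
from collections import defaultdict

def bin_words(row_counts, base=2.0):
    """Group word indices by empirical marginal rho_hat_i = count_i / total.

    Word i goes to bin floor(-log_base(rho_hat_i)), so each bin holds words
    whose marginals agree up to a factor of `base`. Unseen words (count 0)
    are collected under the key None.
    """
    total = sum(row_counts)
    bins = defaultdict(list)
    for word, count in enumerate(row_counts):
        if count == 0:
            bins[None].append(word)
            continue
        rho_hat = count / total
        bins[math.floor(-math.log(rho_hat, base))].append(word)
    return dict(bins)

def block(counts, rows, cols):
    """Extract the bin-by-bin block of a count matrix given as nested lists."""
    return [[counts[i][j] for j in cols] for i in rows]

# Toy marginal counts for M = 5 words (hypothetical numbers).
row_counts = [60, 30, 6, 4, 0]
bins = bin_words(row_counts)
```

Within a bin the marginals are roughly homogeneous, which is the condition under which the regularized t-SVD from the community detection setting applies; each pair of bins then defines one block passed to `block`.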
Conclusion
Spectral methods are powerful tools for learning mixture models.
We can go beyond worst-case analysis by exploiting randomness in the analysis and in the algorithm.
To make spectral algorithms more practical, one needs careful algorithm implementation to improve sample complexity.

Future works
Robustness: agnostic learning, generalization error analysis.
Dynamics: extend the analysis techniques and algorithmic ideas to learning of random processes, with streaming data and iterative algorithms.
Get hands dirty with real ...
References
- Learning Mixture of Gaussians in High Dimensions, R. Ge, Q. Huang, S. Kakade (STOC 2015)
- Super-Resolution off the Grid, Q. Huang, S. Kakade (NIPS 2015)
- Minimal Realization Problems for Hidden Markov Models, Q. Huang, R. Ge, S. Kakade, M. Dahleh (IEEE Transactions on Signal Processing, 2016)
- Recovering Structured Probability Matrices, Q. Huang, S. Kakade, W. Kong, G. Valiant (submitted to STOC 2016)
Tensor Decomposition
A tensor is a multi-way array (as in MATLAB). A 2-way tensor is a matrix.
3-way tensor: $M_{j_1, j_2, j_3}$, with $j_1 \in [n_A]$, $j_2 \in [n_B]$, $j_3 \in [n_C]$.

Tensor Decomposition
Rank-one tensor: $[a \otimes b \otimes c]_{j_1, j_2, j_3} = a_{j_1} b_{j_2} c_{j_3}$.
Sum of rank-one tensors: $M = \sum_{i=1}^{k} A_{[:,i]} \otimes B_{[:,i]} \otimes C_{[:,i]} = A \otimes B \otimes C$.
Tensor rank: the minimum number of summands in such a rank-one decomposition.
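Tensor Decomposition
$M = \sum_{i=1}^{k} A_{[:,i]} \otimes B_{[:,i]} \otimes C_{[:,i]} = A \otimes B \otimes C$
Sufficient condition for unique tensor decomposition: if $A$ and $B$ are of full rank k, and $C$ has rank at least 2 (more precisely, no two columns of $C$ are parallel), then we can decompose $M$ to uniquely find the factors $A$, $B$, $C$ (up to permutation and scaling of columns) in poly time.
Stability depends polynomially on the condition numbers of $A$, $B$, $C$ (the algorithm boils down to matrix SVD).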
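One standard way to realize such a decomposition is simultaneous diagonalization, often attributed to Jennrich. The sketch below is a noiseless illustration using eigendecompositions and a pseudo-inverse rather than the SVD-based routine the slide alludes to; the function name `decompose_3way`, the random contraction vectors, and the `rcond` cutoff are assumptions made for the example.

```python
import numpy as np

def decompose_3way(M, k, rng):
    """Recover factors of M = sum_i a_i (x) b_i (x) c_i, up to column scaling
    and permutation, assuming A and B have full column rank k and no two
    columns of C are parallel (noiseless input)."""
    n_a, n_b, n_c = M.shape
    x, y = rng.standard_normal(n_c), rng.standard_normal(n_c)
    # Contract the third mode: Mx = A diag(C^T x) B^T, My = A diag(C^T y) B^T.
    Mx = np.einsum('abc,c->ab', M, x)
    My = np.einsum('abc,c->ab', M, y)
    My_pinv = np.linalg.pinv(My, rcond=1e-10)
    # Mx My^+ = A D A^+ and Mx^T (My^T)^+ = B D B^+ share the eigenvalues D.
    vals_a, vecs_a = np.linalg.eig(Mx @ My_pinv)
    vals_b, vecs_b = np.linalg.eig(Mx.T @ My_pinv.T)

    def top_k(vals, vecs):
        idx = np.argsort(-np.abs(vals))[:k]        # drop the zero eigenvalues
        idx = idx[np.argsort(vals[idx].real)]      # align A and B by eigenvalue
        return vecs[:, idx].real

    A_hat, B_hat = top_k(vals_a, vecs_a), top_k(vals_b, vecs_b)
    # Solve the linear system M_(3) = C_hat K^T for C_hat (absorbs the scaling).
    K = np.einsum('ai,bi->abi', A_hat, B_hat).reshape(n_a * n_b, k)
    C_hat = np.linalg.lstsq(K, M.reshape(n_a * n_b, n_c), rcond=None)[0].T
    return A_hat, B_hat, C_hat

rng = np.random.default_rng(0)
k = 3
A, B, C = (rng.standard_normal((n, k)) for n in (5, 6, 4))
M = np.einsum('ai,bi,ci->abc', A, B, C)

A_hat, B_hat, C_hat = decompose_3way(M, k, rng)
M_hat = np.einsum('ai,bi,ci->abc', A_hat, B_hat, C_hat)
max_err = np.abs(M_hat - M).max()
```

The recovered columns match the true ones only up to scaling and ordering, so the check compares the reconstructed tensor `M_hat` with `M` rather than the factors entry by entry.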