Bayesian Learning and Inference in Recurrent Switching Linear Dynamical Systems


Scott W. Linderman (Columbia University), Matthew J. Johnson (Harvard and Google Brain), Andrew C. Miller (Harvard University), Ryan P. Adams (Harvard and Google Brain), David M. Blei (Columbia University), Liam Paninski (Columbia University)

Supplementary Material

1 Stochastic Variational Inference

The main paper introduces a Gibbs sampling algorithm for the recurrent SLDS and its siblings, but it is straightforward to derive a mean field variational inference algorithm as well. From this, we can immediately derive a stochastic variational inference (SVI) [Hoffman et al., 2013] algorithm for conditionally independent time series. We use a structured mean field approximation on the augmented model,

p(z_{1:T}, x_{1:T}, \omega_{1:T}, \theta \mid y_{1:T}) \approx q(z_{1:T}) \, q(x_{1:T}) \, q(\omega_{1:T}) \, q(\theta).

The first three factors will not be explicitly parameterized; rather, as with Gibbs sampling, we leverage standard message passing algorithms to compute the necessary expectations with respect to these factors. Moreover, q(\omega_{1:T}) further factorizes as,

q(\omega_{1:T}) = \prod_{t=1}^{T} \prod_{k=1}^{K} q(\omega_{t,k}).

To be concrete, we also expand the parameter factor,

q(\theta) = \prod_{k=1}^{K} q(R_k, r_k; \eta_{\mathrm{rec},k}) \, q(A_k, B_k, b_k; \eta_{\mathrm{dyn},k}) \, q(C_k, D_k, d_k; \eta_{\mathrm{obs},k}).

The algorithm proceeds by alternating between optimizing q(x_{1:T}), q(z_{1:T}), q(\omega_{1:T}), and q(\theta).

Updating q(x_{1:T}). Fixing the factor on the discrete states q(z_{1:T}), the optimal variational factor on the continuous states q(x_{1:T}) is determined by,

\ln q(x_{1:T}) = \psi(x_1) + \sum_{t=1}^{T-1} \psi(x_t, x_{t+1}) + \sum_{t=1}^{T-1} \psi(x_t, z_{t+1}, \omega_t) + \sum_{t=1}^{T} \psi(x_t; y_t) + c,

where

\psi(x_1) = \mathbb{E}_{q(\theta) q(z)} \ln p(x_1 \mid z_1, \theta),   (1)
\psi(x_t, x_{t+1}) = \mathbb{E}_{q(\theta) q(z)} \ln p(x_{t+1} \mid x_t, z_t, \theta),   (2)
\psi(x_t, z_{t+1}, \omega_t) = \mathbb{E}_{q(\theta) q(z) q(\omega)} \ln p(z_{t+1} \mid x_t, z_t, \omega_t, \theta).   (3)

Because the densities p(x_1 \mid z_1, \theta) and p(x_{t+1} \mid x_t, z_t, \theta) are Gaussian exponential families, the expectations in Eqs. (1)-(2) can be computed efficiently, yielding Gaussian potentials with natural parameters that depend on both q(\theta) and q(z_{1:T}). Furthermore, each \psi(x_t; y_t) is itself a Gaussian potential. As in the Gibbs sampler, the only non-Gaussian potential comes from the logistic stick breaking model, but once again, the Pólya-gamma augmentation scheme comes to the rescue. After augmentation, the potential as a function of x_t is,

\mathbb{E}_{q(\theta) q(z) q(\omega)} \ln p(z_{t+1} \mid x_t, z_t, \omega_t, \theta) = -\tfrac{1}{2} \nu_{t+1}^\top \Omega_t \nu_{t+1} + \nu_{t+1}^\top \kappa(z_{t+1}) + c,

where \Omega_t = \mathrm{diag}(\omega_t). Since \nu_{t+1} = R_{z_t} x_t + r_{z_t} is linear in x_t, this is another Gaussian potential.
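As a concrete illustration of this step, here is a minimal sketch (not from the paper) of assembling the information-form Gaussian potential on x_t implied by the augmented recurrence term. It simply plugs in posterior means for (R_{z_t}, r_{z_t}) and the expected Pólya-gamma variables; the exact mean field update would also involve second moments of (R, r), and the function name and array shapes below are assumptions.

```python
import numpy as np

def recurrence_potential(E_R, E_r, E_omega_t, kappa_z_next):
    """Gaussian potential on x_t from the Polya-gamma-augmented stick-breaking term,
        -0.5 * nu^T Omega_t nu + nu^T kappa(z_{t+1}),   with nu = R x_t + r.

    Expanding in x_t gives information parameters
        J = R^T Omega_t R,    h = R^T (kappa(z_{t+1}) - Omega_t r).

    E_R          : (K-1, D) posterior mean of the recurrence weights R_{z_t}
    E_r          : (K-1,)   posterior mean of the recurrence bias r_{z_t}
    E_omega_t    : (K-1,)   expected Polya-gamma variables E_q[omega_{t,k}]
    kappa_z_next : (K-1,)   kappa(z_{t+1}) from the stick-breaking augmentation,
                            e.g. I[z_{t+1} = k] - 0.5 * I[z_{t+1} >= k]
    """
    Omega_t = np.diag(E_omega_t)
    J = E_R.T @ Omega_t @ E_R                   # precision contribution to x_t
    h = E_R.T @ (kappa_z_next - Omega_t @ E_r)  # linear contribution to x_t
    return J, h
```

In the full update, these contributions are added, time step by time step, to the natural parameters of the Gaussian linear dynamical system factor before running message passing.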

As with the dynamics and observation potentials, the recurrence weights, (R_k, r_k), also have matrix normal factors, which are conjugate after augmentation. We also need access to \mathbb{E}_q[\omega_{t,k}]; we discuss this computation below. After augmentation, the overall factor q(x_{1:T}) is a Gaussian linear dynamical system with natural parameters computed from the variational factor on the dynamical parameters q(\theta), the variational factor on the discrete states q(z_{1:T}), the recurrence potentials \{\psi(x_t, z_t, z_{t+1})\}_{t=1}^{T-1}, and the observation model potentials \{\psi(x_t; y_t)\}_{t=1}^{T}. Because the optimal factor q(x_{1:T}) is a Gaussian linear dynamical system, we can use message passing to perform efficient inference. In particular, the expected sufficient statistics of q(x_{1:T}) needed for updating q(z_{1:T}) can be computed efficiently.

Updating q(\omega_{1:T}). We have,

\ln q(\omega_{t,k}) = \mathbb{E}_q \ln p(z_{t+1}, \omega_t \mid x_t) + c
= -\tfrac{1}{2} \mathbb{E}_q[\nu_{t+1,k}^2] \, \omega_{t,k} + \mathbb{E}_{q(z_{1:T})} \ln p_{\mathrm{PG}}(\omega_{t,k} \mid \mathbb{I}[z_{t+1} \geq k], 0) + c.

While the expectation with respect to q(z_{1:T}) makes this challenging, we can approximate it with a sample, \hat{z}_{1:T} \sim q(z_{1:T}). Given a fixed value \hat{z}_{1:T}, we have

q(\omega_{t,k}) = p_{\mathrm{PG}}\big(\omega_{t,k} \mid \mathbb{I}[\hat{z}_{t+1} \geq k], \sqrt{\mathbb{E}_q[\nu_{t+1,k}^2]}\big).

The expected value of this distribution is available in closed form:

\mathbb{E}_q[\omega_{t,k}] = \frac{\mathbb{I}[\hat{z}_{t+1} \geq k]}{2 \sqrt{\mathbb{E}_q[\nu_{t+1,k}^2]}} \tanh\Big(\tfrac{1}{2} \sqrt{\mathbb{E}_q[\nu_{t+1,k}^2]}\Big).

Updating q(z_{1:T}). Similarly, fixing q(x_{1:T}), the optimal factor q(z_{1:T}) is proportional to

\exp\Big\{ \psi(z_1) + \sum_{t=1}^{T-1} \psi(z_t, x_t, z_{t+1}) + \sum_{t=1}^{T-1} \psi(z_t) \Big\},

where

\psi(z_1) = \mathbb{E}_{q(\theta)} \ln p(z_1) + \mathbb{E}_{q(\theta) q(x)} \ln p(x_1 \mid z_1, \theta),
\psi(z_t, x_t, z_{t+1}) = \mathbb{E}_{q(\theta) q(x_{1:T})} \ln p(z_{t+1} \mid z_t, x_t),
\psi(z_t) = \mathbb{E}_{q(\theta) q(x)} \ln p(x_{t+1} \mid x_t, z_t, \theta).

The first and third densities are exponential families; these expectations can be computed efficiently. The challenge is the recurrence potential,

\psi(z_t, x_t, z_{t+1}) = \mathbb{E}_{q(\theta) q(x)} \ln \pi_{\mathrm{SB}}(\nu_{t+1}).

Since this is not available in closed form, we approximate this expectation with Monte Carlo over x_t, R_k, and r_k. The resulting factor q(z_{1:T}) is an HMM with natural parameters that are functions of q(\theta) and q(x_{1:T}), and the expected sufficient statistics required for updating q(x_{1:T}) can be computed efficiently by message passing in the same manner.

Updating q(\theta). To compute the expected sufficient statistics for the mean field update on \theta, we can also use message passing, this time in both factors q(x_{1:T}) and q(z_{1:T}) separately. The required expected sufficient statistics are of the form

\mathbb{E}_{q(z)} \mathbb{I}[z_t = i, z_{t+1} = j], \quad \mathbb{E}_{q(z)} \mathbb{I}[z_t = i], \quad \mathbb{E}_{q(z)} \mathbb{I}[z_t = k] \, \mathbb{E}_{q(x)}[x_t x_{t+1}^\top],   (4)
\mathbb{E}_{q(z)} \mathbb{I}[z_t = k] \, \mathbb{E}_{q(x)}[x_t x_t^\top], \quad \mathbb{E}_{q(z)} \mathbb{I}[z_1 = k] \, \mathbb{E}_{q(x)}[x_1],

where \mathbb{I}[\cdot] denotes an indicator function. Each of these can be computed easily from the marginals q(x_t, x_{t+1}) and q(z_t, z_{t+1}) for t = 1, 2, \ldots, T, and these marginals can be computed in terms of the respective graphical model messages. Given the conjugacy of the augmented model, the dynamics and observation factors will be MNIW distributions as well. These allow closed form expressions for the required expectations,

\mathbb{E}_q[A_k], \; \mathbb{E}_q[b_k], \; \mathbb{E}_q[A_k B_k], \; \mathbb{E}_q[b_k B_k], \; \mathbb{E}_q[B_k], \quad \mathbb{E}_q[C_k], \; \mathbb{E}_q[d_k], \; \mathbb{E}_q[C_k D_k], \; \mathbb{E}_q[d_k D_k], \; \mathbb{E}_q[D_k].

Likewise, the conjugate matrix normal prior on (R_k, r_k) provides access to \mathbb{E}_q[R_k], \mathbb{E}_q[R_k R_k^\top], and \mathbb{E}_q[r_k].

Stochastic Variational Inference. Given multiple, conditionally independent observations of time series, \{y^{(p)}_{1:T_p}\}_{p=1}^{P} (using the same notation as in the basketball experiment), it is straightforward to derive a stochastic variational inference (SVI) algorithm [Hoffman et al., 2013]. In each iteration, we sample a random time series; run message passing to compute the optimal local factors, q(z^{(p)}_{1:T_p}), q(x^{(p)}_{1:T_p}), and q(\omega^{(p)}_{1:T_p}); and then use expectations with respect to these local factors as unbiased estimates of expectations with respect to the complete dataset when updating the global parameter factor, q(\theta).
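To make the global step concrete, here is a minimal sketch of the standard SVI natural-parameter update for a conjugate global factor, assuming the expected sufficient statistics from the sampled series have already been computed by message passing; the function names, flattened array representation, and the Robbins-Monro schedule are illustrative assumptions rather than details from the paper.

```python
import numpy as np

def svi_global_step(eta, eta_prior, suff_stats_hat, num_series, step_size):
    """One stochastic natural-gradient step on a conjugate global factor q(theta).

    eta            -- current natural parameters of q(theta), flattened into an array
    eta_prior      -- natural parameters of the prior p(theta)
    suff_stats_hat -- expected sufficient statistics from ONE sampled series,
                      computed by message passing in q(x), q(z), and q(omega)
    num_series     -- P, which rescales the single-series statistics so that the
                      estimate is unbiased for the complete dataset
    step_size      -- rho in (0, 1]
    """
    target = eta_prior + num_series * suff_stats_hat
    return (1.0 - step_size) * eta + step_size * target

def robbins_monro_step_size(i, tau=1.0, forgetting_rate=0.7):
    """Illustrative schedule rho_i = (i + tau)^(-forgetting_rate), with the
    forgetting rate in (0.5, 1] so the usual Robbins-Monro conditions hold."""
    return (i + tau) ** (-forgetting_rate)
```

With such a step-size schedule, the stochastic natural-gradient updates converge to a local optimum of the same variational objective targeted by the batch mean field algorithm.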
Given a single, long time series, we can still derive efficient SVI algorithms that use subsets of the data, as long as we are willing to accept minor, controllable bias [Johnson and Willsky, 2014, Foti et al., 2014].

2 Initialization

Given the complexity of these models, it is important to initialize the parameters and latent states with reasonable values.

We used the following initialization procedure: (i) use probabilistic PCA or factor analysis to initialize the continuous latent states, x_{1:T}, and the observation parameters, C, D, and d; (ii) fit an AR-HMM to x_{1:T} in order to initialize the discrete latent states, z_{1:T}, and the dynamics models, \{A_k, Q_k, b_k\}; and then (iii) greedily fit a decision list with logistic regressions at each node in order to determine a permutation of the latent states most amenable to stick breaking. In practice, the last step alleviates the undesirable dependence on ordering that arises from the stick breaking formulation.

As mentioned in Section 4, one of the less desirable features of the logistic stick breaking regression model is its dependence on the ordering of the output dimensions; in our case, on the permutation of the discrete states \{1, 2, \ldots, K\}. To alleviate this issue, we first do a greedy search over permutations by fitting a decision list to ((x_t, z_t), z_{t+1}) pairs. A decision list is an iterative classifier of the form,

z_{t+1} =
\begin{cases}
o_1 & \text{if } p_1 \\
o_2 & \text{if } \neg p_1 \wedge p_2 \\
o_3 & \text{if } \neg p_1 \wedge \neg p_2 \wedge p_3 \\
\vdots & \\
o_K & \text{otherwise,}
\end{cases}

where (o_1, \ldots, o_K) is a permutation of (1, \ldots, K), and p_1, \ldots, p_K are predicates that depend on (x_t, z_t) and evaluate to true or false. In our case, these predicates are given by logistic functions, p_j = \mathbb{I}[r_j^\top x_t > 0].

We fit the decision list using a greedy approach: to determine o_1 and r_1, we use maximum a posteriori estimation to fit logistic regressions for each of the K possible output values. For the k-th logistic regression, the inputs are x_{1:T} and the outputs are y_t = \mathbb{I}[z_{t+1} = k]. We choose the best logistic regression (measured by log likelihood) as the first output. Then we remove those time points for which z_{t+1} = o_1 from the dataset and repeat, fitting K-1 logistic regressions in order to determine the second output, o_2, and so on. After iterating through all K outputs, we have a permutation of the discrete states. Moreover, the predicates \{r_k\}_{k=1}^{K} serve as an initialization for the recurrence weights, R, in our model.

3 Bernoulli-Lorenz Details

The Pólya-gamma augmentation makes it easy to handle discrete observations in the rSLDS, as illustrated in the Bernoulli-Lorenz experiment. Since the Bernoulli likelihood is given by,

p(y_t \mid x_t, \theta) = \prod_{n=1}^{N} \mathrm{Bern}\big(y_{t,n} \mid \sigma(c_n^\top x_t + d_n)\big) = \prod_{n=1}^{N} \frac{(e^{c_n^\top x_t + d_n})^{y_{t,n}}}{1 + e^{c_n^\top x_t + d_n}},

we see that it matches the form of (7) with

\nu_{t,n} = c_n^\top x_t + d_n, \quad b(y_{t,n}) = 1, \quad \kappa(y_{t,n}) = y_{t,n} - \tfrac{1}{2}.

Thus, we introduce an additional set of Pólya-gamma auxiliary variables, \omega_{t,n} \sim \mathrm{PG}(1, 0), to render the model conjugate. Given these auxiliary variables, the observation potential is proportional to a Gaussian distribution on x_t,

\psi(x_t, y_t) \propto \mathcal{N}\big(C x_t + d \mid \Omega_t^{-1} \kappa(y_t), \, \Omega_t^{-1}\big),
\Omega_t = \mathrm{diag}([\omega_{t,1}, \ldots, \omega_{t,N}]),
\kappa(y_t) = [\kappa(y_{t,1}), \ldots, \kappa(y_{t,N})].

Again, this admits efficient message passing inference for x_{1:T}. In order to update the auxiliary variables, we sample from their conditional distribution, \omega_{t,n} \sim \mathrm{PG}(1, \nu_{t,n}). This augmentation scheme works for binomial, negative binomial, and multinomial observations as well [Polson et al., 2013].

4 Basketball Details

For completeness, Figures 1 and 2 show all K = 30 inferred states of the rAR-HMM (ro) for the basketball data.
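As a concrete illustration of the Bernoulli observation potential described in Section 3, here is a minimal sketch (not from the paper) that builds the information-form Gaussian potential on x_t from C, d, and y_t. It plugs in the closed-form mean of the Pólya-gamma conditional at a current guess of x_t instead of drawing a sample; that plug-in, the helper names, and the array shapes are assumptions for illustration only.

```python
import numpy as np

def pg_mean(b, c, eps=1e-8):
    """Closed-form mean of PG(b, c): E[omega] = b / (2c) * tanh(c / 2).
    The limit as c -> 0 is b / 4."""
    c = np.asarray(c, dtype=float)
    c_safe = np.where(np.abs(c) < eps, 1.0, c)
    return np.where(np.abs(c) < eps, b / 4.0, b / (2.0 * c_safe) * np.tanh(c_safe / 2.0))

def bernoulli_obs_potential(C, d, y_t, x_t_guess):
    """Information-form potential psi(x_t, y_t) from the PG-augmented Bernoulli
    likelihood: J = C^T Omega_t C, h = C^T (kappa(y_t) - Omega_t d), with
    kappa(y) = y - 1/2 and Omega_t = diag(omega_t).

    omega_t is set to its conditional mean at nu = C x_t_guess + d; a Gibbs
    sampler would instead draw omega_{t,n} ~ PG(1, nu_{t,n}).
    """
    nu = C @ x_t_guess + d            # nu_{t,n} = c_n^T x_t + d_n
    omega = pg_mean(1.0, nu)          # E[omega_{t,n}] under PG(1, nu_{t,n})
    kappa = y_t - 0.5                 # kappa(y_{t,n}) = y_{t,n} - 1/2
    Omega = np.diag(omega)
    J = C.T @ Omega @ C               # precision contribution to x_t
    h = C.T @ (kappa - Omega @ d)     # linear contribution to x_t
    return J, h
```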
References

Nicholas Foti, Jason Xu, Dillon Laird, and Emily Fox. Stochastic variational inference for hidden Markov models. In Advances in Neural Information Processing Systems, pages 3599–3607, 2014.

Matthew D. Hoffman, David M. Blei, Chong Wang, and John William Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14(1):1303–1347, 2013.

Matthew J. Johnson and Alan S. Willsky. Stochastic variational inference for Bayesian time series models. In Proceedings of the 31st International Conference on Machine Learning (ICML-14), pages 1854–1862, 2014.

Nicholas G. Polson, James G. Scott, and Jesse Windle. Bayesian inference for logistic models using Pólya-gamma latent variables. Journal of the American Statistical Association, 108(504):1339–1349, 2013.

Figure 1: All of the inferred basketball states

Figure 2: All of the inferred basketball states