Auto-Encoding Variational Bayes. Stochastic Backpropagation and Approximate Inference in Deep Generative Models

Auto-Encoding Variational Bayes
Diederik Kingma and Max Welling

Stochastic Backpropagation and Approximate Inference in Deep Generative Models
Danilo J. Rezende, Shakir Mohamed, Daan Wierstra

Neural Variational Inference and Learning in Belief Networks
Andriy Mnih, Karol Gregor

September 26, 2014

Overview
- Deep models are becoming increasingly popular.
- Directed probabilistic models are often computationally intensive and difficult to learn.
- In addition to a high training cost, directed graphical models are often expensive at runtime, because latent parameters must be fit for each new sample.
- Compare to deep neural nets:
  - High training cost (although less than directed probabilistic models)
  - Low testing cost (just requires matrix multiplications and simple nonlinear function evaluations)
- This group of papers introduces methods for efficient learning of directed graphical models, as well as efficient evaluation at test time.

Quick Review of Variational Bayes (VB)
- Given a generative model, even finding $\theta_{MAP}$ can be difficult.
- The EM algorithm and VB methods have been proposed to infer estimates of $\theta$.
- In many problems, the posterior distribution $p(z_n \mid x_n, \theta)$ is intractable, which can lead to issues with VB and EM algorithms.

Figure: Generative model from Kingma and Welling. $\theta$ are model parameters, $\{x_n, z_n\}$ is a sample from the model with only $x_n$ observed. $\phi$ denotes variables added in the VB inference scheme.

Review by: David Carlson

Quick Review of Variational Bayes (VB)
- Let $p_\theta(x_n, z_n)$ denote the model likelihood for data sample $n$ for a set of parameters $\theta$.
- Let $q_\phi(z_n \mid x_n)$ denote a proposed distribution, with parameters $\phi$, that approximates the intractable true posterior $p_\theta(z_n \mid x_n)$.
- The well-known variational equality is:

$$\log p_\theta(x_{1:N}) = D_{KL}\big(q_\phi(\theta, z_{1:N} \mid x_{1:N}) \,\|\, p(\theta, z_{1:N} \mid x_{1:N})\big) + \mathcal{L}(\theta, \phi; x_{1:N})$$

$$\mathcal{L}(\theta, \phi; x_{1:N}) = E_{q(\theta)\, q_\phi(z_{1:N} \mid x_{1:N})}\left[\log \frac{p_\theta(x_{1:N}, z_{1:N})\, p(\theta)}{q_{\phi_{1:N}}(z_{1:N} \mid x_{1:N})\, q(\theta)}\right]$$

- Most VB algorithms focus on the lower bound:

$$\log p_\theta(x_{1:N}) \ge \mathcal{L}(\theta, \phi_{1:N}; x_{1:N}) \quad (1)$$
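A short derivation may make the equality less opaque. The sketch below assumes a single datapoint $x$ and a point estimate of $\theta$ (so $q(\theta)$ and $p(\theta)$ drop out); this simplification is an assumption made for illustration and is not taken from the slides.

```latex
\begin{align*}
\log p_\theta(x)
  &= E_{q_\phi(z \mid x)}\left[\log p_\theta(x)\right] \\
  &= E_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x, z)}{p_\theta(z \mid x)}\right] \\
  &= E_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]
   + E_{q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z \mid x)}\right] \\
  &= \mathcal{L}(\theta, \phi; x)
   + D_{KL}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)
\end{align*}
```

Because the KL term is nonnegative, dropping it leaves a lower bound on $\log p_\theta(x)$, and the bound is tight exactly when $q_\phi(z \mid x)$ matches the true posterior.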

Quick Review of Variational Bayes (VB)
- A VB algorithm attempts to maximize the ELBO:

$$\{\theta, \phi\} = \arg\max_{\phi, \theta} \mathcal{L}(\theta, \phi_{1:N}; x_{1:N}) \quad (2)$$

- Usually, a separate $\phi_n$ is learned for each sample, so that the approximate distribution for $z_n$ is independent of $x_n$ given the parameters $\phi_n$:

$$q_{\phi_n}(z \mid x_n) = q_{\phi_n}(z) \quad (3)$$

- The algorithm would be as follows:
  1. Update $\phi$ to maximize $\mathcal{L}(\theta, \phi; x_{1,\dots,N})$
  2. Update $\theta$ to maximize $E_{q_\phi(z \mid x_n)\, q(\theta)}[\log p_\theta(x_n \mid z)]$
- If $q(\theta)$ is reduced to a point estimate, this recovers the EM algorithm.

Recognition Model
- Finding the optimal $\phi_{1:N}$ to maximize $\mathcal{L}(\theta, \phi_{1:N}; x_{1,\dots,N})$ is straightforward in conjugate models, but computationally expensive or intractable in many desirable models.
- The key idea of the Recognition Model is:
  - Don't learn sample-specific VB parameters $\phi_n$
  - Learn a shared $\phi$ that defines a function from $x_n$ to a distribution $q_\phi(z_n \mid x_n)$ that is simple to evaluate and sample
- Unlike normal VB, $q_\phi(z_n \mid x_n)$ depends on $x_n$.
- Fitting a Recognition Model is not necessarily straightforward, and a poor choice of recognition model will lead to poor performance.
- Today's papers focus on fitting and using recognition models in specific cases, including deep learning.

Recognition Model: Sigmoid Belief Net
- Consider the following model:

$$z_{nj} \sim \text{Bern}(\pi_j) \quad \text{for } j = 1, \dots, J \quad (4)$$

$$x_{nm} \sim \text{Bern}(\sigma(c_m + w_m^T z_n)) \quad \text{for } m = 1, \dots, M \quad (5)$$

- In a typical VB optimization scheme, each data sample $n$ would have a separate set of parameters such that:

$$q_{\phi_n}(z_n \mid x_n) = \prod_j q_{\phi_n}(z_{nj}) = \prod_j \text{Bern}(z_{nj}; \pi_{nj}) \quad (6)$$

- This is computationally intensive.
- Instead, try to learn a recognition model. There is no general form; a simple form could be:

$$q_\phi(z_n \mid x_n) = \prod_j q_\phi(z_{nj} \mid x_n) = \prod_j \text{Bern}(z_{nj}; \sigma(d_j + a_j^T x_n)) \quad (7)$$
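The simple recognition model of eq. (7) can be sketched in a few lines. This is a minimal illustration using only the Python standard library; the toy sizes, the weight values in `A` and `d`, and the function names are assumptions made for the example, not part of the slides.

```python
import math
import random

def sigmoid(t):
    """Logistic function sigma(t) = 1 / (1 + exp(-t))."""
    return 1.0 / (1.0 + math.exp(-t))

def recognition_probs(x, A, d):
    """Return q(z_j = 1 | x) = sigma(d_j + a_j^T x) for every latent unit j."""
    return [sigmoid(d_j + sum(a * x_m for a, x_m in zip(a_j, x)))
            for a_j, d_j in zip(A, d)]

def sample_latents(probs, rng):
    """Draw z_j ~ Bern(probs[j]) independently for each j."""
    return [1 if rng.random() < p else 0 for p in probs]

def log_q(z, probs):
    """log q(z | x) under the factorized Bernoulli recognition model."""
    return sum(math.log(p) if z_j == 1 else math.log(1.0 - p)
               for z_j, p in zip(z, probs))

# Hypothetical toy sizes: M = 3 observed units, J = 2 latent units.
A = [[0.5, -1.0, 0.25], [1.5, 0.0, -0.5]]   # row j holds a_j
d = [0.0, -0.5]
x = [1, 0, 1]

rng = random.Random(0)
probs = recognition_probs(x, A, d)   # shared parameters, evaluated at this x
z = sample_latents(probs, rng)
```

The same `A` and `d` are reused for every datapoint, which is the point of the recognition model: a new `x` costs one matrix-vector product rather than a per-sample optimization over $\pi_{nj}$.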

Learning the recognition model
- The papers here focus on different aspects of using recognition models:
  - General learning schemes for recognition models (Kingma and Welling)
  - Variance reduction for Monte Carlo gradient estimates (Mnih and Gregor)
  - Specific forms of recognition models for deep learning (Rezende et al)
- We briefly discuss the main ideas of the papers and a subset of the results.

Learning the Recognition Model
- Let $z^{(l)} \sim q_\phi(z \mid x)$.
- There is a naive Monte Carlo gradient estimator for gradients of expectations of functions:

$$\nabla_\phi E_{q_\phi(z \mid x)}[f(z)] = E_{q_\phi(z \mid x)}[f(z)\, \nabla_\phi \log q_\phi(z \mid x)] \approx \frac{1}{L} \sum_{l=1}^{L} f(z^{(l)})\, \nabla_\phi \log q_\phi(z^{(l)} \mid x)$$

- Applied to the standard VB lower bound, this estimator exhibits very high variance and is impractical for most purposes considered here.
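To make the naive estimator concrete, the sketch below checks it on a one-dimensional toy case where the exact gradient is known: $z \sim \text{Bern}(\sigma(\phi))$, for which $\nabla_\phi \log q_\phi(z) = z - \sigma(\phi)$. The test function `f`, the value of `phi`, and the sample count are assumptions chosen for illustration.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def f(z):
    """Arbitrary test function of the binary latent z (illustrative only)."""
    return 3.0 if z == 1 else 1.0

def exact_gradient(phi):
    """d/dphi E_q[f(z)] for z ~ Bern(sigmoid(phi)), in closed form."""
    p = sigmoid(phi)
    return p * (1.0 - p) * (f(1) - f(0))

def score_function_gradient(phi, num_samples, rng):
    """Monte Carlo estimate (1/L) * sum_l f(z_l) * d/dphi log q_phi(z_l)."""
    p = sigmoid(phi)
    total = 0.0
    for _ in range(num_samples):
        z = 1 if rng.random() < p else 0
        score = z - p            # d/dphi log Bern(z; sigmoid(phi))
        total += f(z) * score
    return total / num_samples

rng = random.Random(0)
phi = 0.4
estimate = score_function_gradient(phi, 200_000, rng)
exact = exact_gradient(phi)
```

Even in this tiny problem a large number of samples is needed before `estimate` settles near `exact`; when `f` is a log-probability ratio with a large magnitude, the variance is far worse.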

The SGVB estimator (Kingma and Welling)
- This algorithm notes that the variational lower bound can be written as (for a point estimate on $\theta$):

$$\mathcal{L}(\phi, \theta; x_n) = -D_{KL}(q_\phi(z \mid x_n) \,\|\, p_\theta(z)) + E_{q_\phi(z \mid x_n)}[\log p_\theta(x_n \mid z)]$$

- The K-L divergence is now between the recognition model and the prior.
- In standard VB models, this first term is intractable.
- In recognition models, the first term is often explicitly (and easily) calculable.
- The Stochastic Gradient Variational Bayes (SGVB) method computes the gradient of the first term exactly and uses the standard sampling scheme for the second term.
- Empirically, this estimator shows lower variance.
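The split between an exact KL term and a sampled reconstruction term can be sketched as follows, assuming a diagonal Gaussian recognition model and a standard normal prior, for which the KL divergence has the closed form $\frac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1)$. The decoder `toy_log_likelihood` is a hypothetical stand-in for a neural network, and all numeric values are illustrative.

```python
import math
import random

def kl_diag_gaussian_to_standard(mu, sigma):
    """Closed-form D_KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return 0.5 * sum(m * m + s * s - 2.0 * math.log(s) - 1.0
                     for m, s in zip(mu, sigma))

def reparameterized_sample(mu, sigma, rng):
    """z = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]

def sgvb_estimate(x, mu, sigma, log_likelihood, num_samples, rng):
    """-KL (exact) + (1/L) * sum_l log p(x | z_l) (sampled)."""
    kl = kl_diag_gaussian_to_standard(mu, sigma)
    recon = sum(log_likelihood(x, reparameterized_sample(mu, sigma, rng))
                for _ in range(num_samples)) / num_samples
    return -kl + recon

def toy_log_likelihood(x, z):
    """Hypothetical decoder: x ~ N(sum(z), 1), standing in for a network."""
    mean = sum(z)
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - mean) ** 2

rng = random.Random(0)
mu, sigma = [0.5, -0.2], [0.8, 1.1]   # would be outputs of the recognition model
bound = sgvb_estimate(1.0, mu, sigma, toy_log_likelihood, 1000, rng)
```

Only the reconstruction term carries sampling noise here; the KL term contributes none, which is one way to see why this form tends to have lower variance than estimating the whole bound by sampling.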

The SGVB estimator (Kingma and Welling)

Algorithm 1 (from Kingma and Welling): Minibatch version of the Auto-Encoding VB (AEVB) algorithm. Either of the two SGVB estimators in section 2.3 can be used. We use settings M = 100 and L = 1 in experiments.
    $\theta, \phi \leftarrow$ Initialize parameters
    repeat
        $X^M \leftarrow$ Random minibatch of $M$ datapoints (drawn from full dataset)
        $\epsilon \leftarrow$ Random samples from noise distribution $p(\epsilon)$
        $g \leftarrow \nabla_{\theta,\phi} \tilde{\mathcal{L}}^M(\theta, \phi; X^M, \epsilon)$ (Gradients of minibatch estimator (8))
        $\theta, \phi \leftarrow$ Update parameters using gradients $g$ (e.g. SGD or Adagrad [DHS10])
    until convergence of parameters $(\theta, \phi)$
    return $\theta, \phi$

- This algorithm was run on a Variational Auto-Encoder for the MNIST and Frey Face datasets.
- The Variational Auto-Encoder is defined as:

$$p(x \mid z) \sim \text{Bern}(\sigma(W_2 \tanh(W_1 z + b_1) + b_2)) \quad (8)$$

$$q(z \mid x) \sim \mathcal{N}(z; W_4 h + b_4, \text{diag}(W_5 h + b_5)) \quad (9)$$

$$h = \tanh(W_3 x + b_3) \quad (10)$$

Excerpt from Kingma and Welling (the equation numbers (3), (7) and (8) inside the excerpt and in Algorithm 1 are the paper's own): Often, the KL-divergence $D_{KL}(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z))$ of eq. (3) can be integrated analytically (see appendix B), such that only the expected reconstruction error $E_{q_\phi(z \mid x^{(i)})}[\log p_\theta(x^{(i)} \mid z)]$ requires estimation by sampling. The KL-divergence term can then be interpreted as regularizing $\phi$, encouraging the approximate posterior to be close to the prior $p_\theta(z)$. This yields a second version of the SGVB estimator $\tilde{\mathcal{L}}^B(\theta, \phi; x^{(i)}) \simeq \mathcal{L}(\theta, \phi; x^{(i)})$, corresponding to eq. (3), which typically has less variance than the generic estimator:

$$\tilde{\mathcal{L}}^B(\theta, \phi; x^{(i)}) = -D_{KL}(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z)) + \frac{1}{L} \sum_{l=1}^{L} \log p_\theta(x^{(i)} \mid z^{(i,l)}) \quad \text{(paper eq. 7)}$$

where $z^{(i,l)} = g_\phi(\epsilon^{(i,l)}, x^{(i)})$ and $\epsilon^{(l)} \sim p(\epsilon)$. Given multiple datapoints from a dataset $X$ with $N$ datapoints, we can construct an estimator of the marginal likelihood lower bound of the full dataset, based on minibatches:

$$\mathcal{L}(\theta, \phi; X) \simeq \tilde{\mathcal{L}}^M(\theta, \phi; X^M) = \frac{N}{M} \sum_{i=1}^{M} \tilde{\mathcal{L}}(\theta, \phi; x^{(i)}) \quad \text{(paper eq. 8)}$$
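The $N/M$ scaling in the minibatch estimator is easy to check numerically. In the sketch below the per-datapoint bound values are made up for illustration (in practice each would come from the SGVB estimator evaluated at one $x^{(i)}$), and the function name is hypothetical.

```python
def minibatch_bound(per_datapoint_bounds, dataset_size):
    """Scale the summed bound of a minibatch of size M by N / M."""
    M = len(per_datapoint_bounds)
    return dataset_size / M * sum(per_datapoint_bounds)

# Hypothetical per-datapoint bound values for a dataset of N = 6 points.
bounds = [-98.0, -101.5, -97.25, -103.0, -99.75, -100.5]
N, M = len(bounds), 2

full_bound = sum(bounds)                                  # bound on the full dataset
batches = [bounds[i:i + M] for i in range(0, N, M)]       # one pass over the data
estimates = [minibatch_bound(batch, N) for batch in batches]
average_estimate = sum(estimates) / len(estimates)
```

Each individual minibatch estimate differs from `full_bound`, but their average over one pass through the data matches it, which is what makes the scaled minibatch gradient usable inside SGD or Adagrad.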

The SGVB estimator (Kingma and Welling)

Figure 2 (from Kingma and Welling): Comparison of our AEVB method to the wake-sleep algorithm, in terms of optimizing the lower bound, for different dimensionality of latent space ($N_z$). Our method converged considerably faster and reached a better solution in all experiments. Interestingly enough, more latent variables does not result in more overfitting, which is explained by the regularizing effect of the lower bound. Vertical axis: the estimated average variational lower bound per datapoint. The estimator variance was small (< 1) and omitted. Horizontal axis: amount of training points evaluated. Computation took around 20-40 minutes per million training samples with an Intel Xeon CPU running at an effective 40 GFLOPS. Note that with hidden units we refer to the hidden layer of the neural networks of the encoder and decoder.

Neural Variational Inference and Learning (NVIL) (Mnih and Gregor)
- Suppose we are interested in training a latent variable model $P_\theta(x, h)$ with parameters $\theta$.
- $x$ is an observation and $h$ are latent variables.
- Propose a model mapping $x$ to $Q_\phi(h \mid x)$: the inference network (also called the recognition model).
- Require that $Q_\phi(h \mid x)$ is:
  - Efficient to evaluate
  - Efficient to sample
- This work focuses on reducing the variance of gradient estimates.

Neural Variational Inference and Learning (NVIL) (Mnih and Gregor)
- The gradients are written:

$$\nabla_\theta \mathcal{L}(x, \phi, \theta) = E_Q[\nabla_\theta \log P_\theta(x, h)] \quad (11)$$

$$\nabla_\phi \mathcal{L}(x, \phi, \theta) = E_Q[(\log P_\theta(x, h) - \log Q_\phi(h \mid x))\, \nabla_\phi \log Q_\phi(h \mid x)]$$

- Equation 11 is simple to approximate if $Q_\phi(h \mid x)$ is easy to sample.
- Gradient estimates for $\phi$ suffer from high variance under naive Monte Carlo Integration (MCI) schemes.

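Neural Variational Inference and Learning (NVIL) (Mnih and Gregor)
- The MCI scheme to estimate the gradient for $\phi$ is:

$$\nabla_\phi \mathcal{L}(x, \phi, \theta) \approx \frac{1}{n} \sum_{i=1}^{n} l_\phi(x, h^{(i)})\, \nabla_\phi \log Q_\phi(h^{(i)} \mid x) \quad (12)$$

$$l_\phi(x, h^{(i)}) = \log P_\theta(x, h^{(i)}) - \log Q_\phi(h^{(i)} \mid x)$$

- The variance of this estimator can be reduced by a few useful techniques.
- First, note that the expectation is unchanged when a constant offset $c$ is subtracted from $l_\phi$:

$$\nabla_\phi \mathcal{L}(x, \phi, \theta) \approx \frac{1}{n} \sum_{i=1}^{n} (l_\phi(x, h^{(i)}) - c)\, \nabla_\phi \log Q_\phi(h^{(i)} \mid x)$$

- $c$ can be chosen so that $l_\phi - c$ has a mean of 0, reducing variance.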
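The effect of the offset $c$ can be verified exactly on a one-unit example by enumerating both outcomes of $h \sim \text{Bern}(\sigma(\phi))$, where $\nabla_\phi \log Q_\phi(h) = h - \sigma(\phi)$. The learning-signal values in `signal` are made up for illustration; they are only meant to mimic a log-probability gap with a large magnitude.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def estimator_moments(signal, p, c):
    """Exact mean and variance of (signal(h) - c) * (h - p) for h ~ Bern(p).

    With p = sigmoid(phi), the factor (h - p) is d/dphi log Q_phi(h).
    """
    outcomes = [(1, p), (0, 1.0 - p)]
    values = [((signal(h) - c) * (h - p), prob) for h, prob in outcomes]
    mean = sum(v * prob for v, prob in values)
    var = sum((v - mean) ** 2 * prob for v, prob in values)
    return mean, var

def signal(h):
    """Hypothetical learning signal l(x, h), large and negative."""
    return -100.0 if h == 1 else -104.0

p = sigmoid(0.3)
mean_signal = p * signal(1) + (1.0 - p) * signal(0)

mean_raw, var_raw = estimator_moments(signal, p, c=0.0)
mean_base, var_base = estimator_moments(signal, p, c=mean_signal)
```

Both settings of `c` give the same mean, so the estimator stays unbiased, while `var_base` is several orders of magnitude smaller than `var_raw` in this toy case.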

Neural Variational Inference and Learning (NVIL) (Mnih and Gregor)
- In deep networks the layers will naturally factorize:

$$P_\theta(x, h) = P_\theta(x \mid h^1) \prod_{i=1}^{n-1} P_\theta(h^i \mid h^{i+1})\, P_\theta(h^n) \quad (13)$$

$$Q_\phi(h \mid x) = Q_{\phi^1}(h^1 \mid x) \prod_{i=1}^{n-1} Q_{\phi^{i+1}}(h^{i+1} \mid h^i) \quad (14)$$

- Then, for each layer, the weights on the Monte Carlo samples can be adapted (with $h^0 \equiv x$ and $c_i$ a layer-specific offset):

$$l^i_\phi(x, h) = \log P_\theta(h^{(i-1):n}) - \log Q_\phi(h^{i:n} \mid h^{i-1}) + c_i$$

Sigmoid Belief Net MNIST Results (Mnih and Gregor)

Figure 1 (from Mnih and Gregor, Neural Variational Inference and Learning in Belief Networks): Bounds on the validation set log-likelihood for an SBN with (Left) one and (Right) two layers of 200 latent variables. Baseline and IDB refer to the input-independent and the input-dependent baselines respectively. VN is variance normalization. Horizontal axis: number of parameter updates. Vertical axis: validation set bound. Curves: Baseline, IDB, & VN; Baseline & VN; Baseline only; VN only; No baselines & no VN.

- VN keeps an estimate of the standard deviation of the gradient (similar to ADAgrad).
- Baseline subtracts a data-independent constant $c$ from $l_\phi$.
- IDB subtracts a data-dependent baseline $C_\phi(x)$.

Table 1 (from Mnih and Gregor): Results on the binarized MNIST dataset. Dim is the number of latent variables in each layer, starting with the deepest one. NVIL and WS refer to the models trained with NVIL and wake-sleep respectively. NLL is the negative log-likelihood for the tractable models and an estimate of it for the intractable ones.

MODEL   DIM            TEST NLL (NVIL)   TEST NLL (WS)
SBN     200            113.1             120.8
SBN     500            112.8             121.4
SBN     200-200        99.8              107.7
SBN     200-200-200    96.7              102.2
SBN     200-200-500    97.0              102.3

Excerpt from Mnih and Gregor: [...] perform much better than a mixture of 500 factorial Bernoulli distributions (MoB) but not as well as the deterministic Neural Autoregressive Distribution Estimator (NADE) (Larochelle & Murray, 2011). The NVIL-trained fDARN models with 200 and 500 latent variables also outperform the fDARN (as well as the more expressive DARN) model with 400 latent variables from (Gregor et al., 2013), which were trained using an MDL-based algorithm. The fDARN and multi-layer SBN models trained using NVIL also outperform a 500-hidden-unit RBM trained with 3-step contrastive divergence (CD), but not the one trained with 25-step CD (Salakhutdinov & Murray, 2008).

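Deep Latent Gaussian Models (DLGMs) (Rezende et al)

Figure: (a) graphical model of the generative model, with layers $h_{n,2}$, $h_{n,1}$ and observation $v_n$ for $n = 1 \dots N$; (b) the corresponding computational graph.

- Let $T_l$ be a nonlinearity associated with the $l$-th layer (may be learned).
- $G_l$ implies a covariance matrix $G_l G_l^T$ on the added signal, and may add low-rank covariance.

$$\zeta_l \sim \mathcal{N}(\zeta_l \mid 0, I) \quad (15)$$

$$h_L = G_L \zeta_L \quad (16)$$

$$h_l = T_l(h_{l+1}) + G_l \zeta_l \quad (17)$$

$$v \sim \pi(v \mid T_0(h_1)) \quad (18)$$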
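Equations (15) to (18) describe ancestral sampling from the top layer down, which can be sketched as below. The sketch assumes that every layer has the same width, that $T_1$ is an elementwise tanh, and that $\pi$ is a Bernoulli whose means are a sigmoid of $h_1$; these choices and all numeric values are illustrative assumptions, not the models used by Rezende et al.

```python
import math
import random

def matvec(G, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(g * v_j for g, v_j in zip(row, v)) for row in G]

def sample_dlgm(G, T, rng):
    """Ancestral sample from a DLGM with L layers.

    G: [G_1, ..., G_L], each a list of rows.
    T: [T_0, T_1, ..., T_{L-1}], deterministic maps between layers.
    Returns ([h_1, ..., h_L], Bernoulli means of v, sampled v).
    """
    L = len(G)
    hidden = {}
    for l in range(L, 0, -1):
        G_l = G[l - 1]
        zeta = [rng.gauss(0.0, 1.0) for _ in range(len(G_l[0]))]  # eq. (15)
        noise = matvec(G_l, zeta)                                  # G_l zeta_l
        if l == L:
            hidden[l] = noise                                      # eq. (16)
        else:
            top = T[l](hidden[l + 1])                              # T_l(h_{l+1})
            hidden[l] = [t + n for t, n in zip(top, noise)]        # eq. (17)
    means = T[0](hidden[1])                                        # parameters of pi
    v = [1 if rng.random() < m else 0 for m in means]              # eq. (18)
    return [hidden[l] for l in range(1, L + 1)], means, v

def tanh_layer(h):
    return [math.tanh(a) for a in h]

def sigmoid_layer(h):
    return [1.0 / (1.0 + math.exp(-a)) for a in h]

# Hypothetical two-layer model with two units per layer.
G = [[[0.5, 0.0], [0.1, 0.3]],      # G_1
     [[1.0, 0.0], [0.0, 1.0]]]      # G_2
T = [sigmoid_layer, tanh_layer]     # T_0, T_1

rng = random.Random(0)
hidden, means, v = sample_dlgm(G, T, rng)
```

All randomness enters through the standard normal draws `zeta`, so each `h_l` is a deterministic function of the noise. That is the property stochastic backpropagation relies on.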

Recognition Model for DLGMs (Rezende et al)

Figure: From Rezende et al. Black arrows indicate the forward pass of the model. Red arrows indicate the backward pass of the model. Solid lines indicate propagation of deterministic activations and dashed lines indicate propagation of samples.

Original caption (Rezende et al., Figure 1): (a) Graphical model for DLGMs (5). (b) The corresponding computational graph. Black arrows indicate the forward pass of sampling from the recognition and generative models: solid lines indicate propagation of deterministic activations, dotted lines indicate propagation of samples. Red arrows indicate the backward pass for gradient computation: solid lines indicate paths where deterministic backpropagation is used, dashed arrows indicate [...]

Scalable Inference for DLGMs (Rezende et al)
- First, start with the variational bound on the negative log-likelihood of the data:

$$-\log p(V) \le \mathcal{F}(V) = D_{KL}[q(\zeta) \,\|\, p(\zeta)] - E_q[\log p(V \mid \zeta, \theta^g)\, p(\theta^g)]$$

- For simplicity, let the recognition model factorize over the $L$ layers:

$$q(\zeta \mid V, \theta^r) = \prod_{n=1}^{N} \prod_{l=1}^{L} \mathcal{N}(\zeta_{n,l} \mid \mu_l(v_n), C_l(v_n)) \quad (19)$$

- If a Gaussian prior and a Gaussian recognition model are used, the bound can be computed explicitly:

$$\mathcal{F}(V) = -\sum_n E_q[\log p(v_n \mid h(\zeta_n))] + \frac{1}{2\kappa} \|\theta^g\|^2 + \frac{1}{2} \sum_{n,l} \left[\|\mu_{n,l}\|^2 + \text{Tr}(C_{n,l}) - \log |C_{n,l}| - 1\right] \quad (20)$$

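Scalable Inference for DLGMs (Rezende et al)
- Gradients for $\theta^g$ are easy to estimate by Monte Carlo integration.
- Gradients for the recognition parameters $\theta^r$ are easy to write down:

$$\nabla_{\mu_l} \mathcal{F}(v) = -E_q[\nabla_{\zeta_l} \log p(v \mid h(\zeta))] + \mu_l \quad (21)$$

$$\nabla_{R_{l,i,j}} \mathcal{F}(v) = -\frac{1}{2} E_q[\epsilon_{l,j}\, \nabla_{\zeta_{l,i}} \log p(v \mid h(\zeta))] + \frac{1}{2} \nabla_{R_{l,i,j}}\left[\text{Tr}\, C_{n,l} - \log |C_{n,l}|\right] \quad (22)$$

- Here $R_l$ is the factor of the covariance, $C_l = R_l R_l^T$, and $\epsilon_l$ is the standard normal noise with $\zeta_l = \mu_l + R_l \epsilon_l$.
- $C_{n,l}$ is computed by backpropagation.
- Sampling methods are used to estimate the gradients.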

Scalable Inference for DLGMs (Rezende et al)
- Maintaining a full covariance matrix in the recognition model is computationally difficult.
- The covariance can be constrained to a diagonal matrix, or the precision to a diagonal plus rank-1 matrix (use the Sherman-Morrison-Woodbury identity for computational efficiency).
- Sampling costs are $O(L K^2)$, where $K$ is the mean number of latent variables per layer.
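The diagonal plus rank-1 precision is cheap to invert because of the Sherman-Morrison identity, $(D + u u^T)^{-1} = D^{-1} - \frac{D^{-1} u u^T D^{-1}}{1 + u^T D^{-1} u}$. The sketch below applies it to a small made-up example and checks the result against the precision matrix; the function name and the values of `d` and `u` are assumptions for illustration, and this is not the parameterization code of Rezende et al.

```python
def rank1_precision_to_covariance(d, u):
    """Covariance C = (diag(d) + u u^T)^{-1} via the Sherman-Morrison identity."""
    w = [u_i / d_i for u_i, d_i in zip(u, d)]              # D^{-1} u
    denom = 1.0 + sum(u_i * w_i for u_i, w_i in zip(u, w)) # 1 + u^T D^{-1} u
    K = len(d)
    return [[(1.0 / d[i] if i == j else 0.0) - w[i] * w[j] / denom
             for j in range(K)] for i in range(K)]

# Hypothetical precision: diagonal d plus rank-1 term u u^T, K = 3.
d = [2.0, 1.0, 4.0]
u = [1.0, -1.0, 0.5]
K = len(d)

C = rank1_precision_to_covariance(d, u)

# Check: precision times covariance should be the identity matrix.
precision = [[(d[i] if i == j else 0.0) + u[i] * u[j] for j in range(K)]
             for i in range(K)]
product = [[sum(precision[i][k] * C[k][j] for k in range(K)) for j in range(K)]
           for i in range(K)]
```

No general matrix inversion is needed: forming `C` explicitly costs $O(K^2)$, and products of `C` with a vector can be computed in $O(K)$ from `d` and `u` alone.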

MNIST Model Data Generation

Figure: From Rezende et al. (Figure 3: Performance on the MNIST dataset.) (a) Left: Training data. Middle: Sampled pixel probabilities. Right: Model samples. Additional panels: (a) NORB, (b) CIFAR.

Combined Binarized MNIST Performance

Model                                                             -log p(v)
NLGBN (Frey & Hinton, 1999)                                       95.8
Wake-Sleep (Dayan, 2000)                                          91.3
RBM (500 h, 25 CD steps) (Salakhutdinov & Murray, 2008)           86.34
DBN 2hl (Salakhutdinov & Hinton, 2009)                            84.55
NADE 1hl (Larochelle & Murray, 2014)                              88.86
EoNADE 2hl (Uria et al, 2014)                                     85.10
DLGM diagonal covariance (Rezende et al, 2014)                    87.30
DLGM rank-one covariance (Rezende et al, 2014)                    86.60
NVIL 200 (Mnih and Gregor, 2014)                                  103.8
NVIL 500 (Mnih and Gregor, 2014)                                  104.4
NVIL 200-200-200 (Mnih and Gregor, 2014)                          94.5
NVIL 200-200-500 (Mnih and Gregor, 2014)                          96.0
DARN n_h = 500 (Gregor et al, 2014)                               84.13