GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution


GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution
Matt Kusner, Jose Miguel Hernandez-Lobato
Presented by Satya Krishna Gorti, University of Toronto (CSC2547)

Introduction

In the GAN methodology, we have a generator network G and a discriminator network D. The discriminator D is trained to predict whether a data instance is synthetic or real. The generator G is trained to confuse D by generating high-quality data. GANs are trained by propagating gradients backward from D to G, which is only feasible if the generated data is continuous.

Discrete data, encoded as a one-hot representation, can be sampled from a categorical distribution; however, this sampling process is not differentiable. A differentiable approximation can instead be obtained from the Gumbel-Softmax distribution.

The Gumbel distribution trick

The CDF of a standard Gumbel distribution is given by $F(g_i) = \exp(-\exp(-g_i))$. A random variable $g_i$ is said to have a standard Gumbel distribution if $g_i = -\log(-\log(U))$ where $U \sim \text{Uniform}[0, 1]$ (inverse transform sampling). The Gumbel-Max trick is a method to sample from a categorical distribution $\text{Cat}(\alpha_1, \alpha_2, \ldots, \alpha_K)$.

Let $\{g_k\}_{k=1}^{K}$ be an i.i.d. sequence of standard Gumbel random variables. The trick is based on the observation that $k_{\max} = \arg\max_k(\log \alpha_k + g_k)$ follows the desired categorical distribution (proof in the blog post by Ryan Adams [1]). Hence a discrete variable sampled from the categorical distribution can be written as

$y = \text{onehot}(\arg\max_k(\log \alpha_k + g_k))$

Therefore, the procedure to sample from a categorical distribution is: draw Gumbel noise by transforming uniform random samples, add it to $\log \alpha_k$ (where $\alpha_k$ can be an unnormalized probability), and take the value of $k$ that produces the maximum.

[1] https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
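The recipe above is easy to check numerically. Below is a minimal NumPy sketch (not from the slides; the function name sample_gumbel_max and the example probabilities are illustrative) that draws categorical samples with the Gumbel-Max trick and compares the empirical frequencies to the target probabilities.

```python
import numpy as np

def sample_gumbel_max(log_alpha, rng):
    u = rng.uniform(size=log_alpha.shape)      # U ~ Uniform[0, 1]
    g = -np.log(-np.log(u))                    # standard Gumbel noise (inverse transform sampling)
    return int(np.argmax(log_alpha + g))       # k_max = arg max_k (log alpha_k + g_k)

rng = np.random.default_rng(0)
log_alpha = np.log(np.array([1.0, 2.0, 7.0]))  # unnormalized probabilities alpha_k
samples = [sample_gumbel_max(log_alpha, rng) for _ in range(10000)]
print(np.bincount(samples, minlength=3) / 10000)   # approximately [0.1, 0.2, 0.7]
```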

Gumbel-Softmax relaxation trick

Since the arg max operator is not continuous, we need a differentiable approximation. The Gumbel-Softmax trick is to approximate the arg max operator with a softmax transformation. We approximate y with:

$y_k = \dfrac{\exp((\log \alpha_k + g_k)/\tau)}{\sum_{i=1}^{K} \exp((\log \alpha_i + g_i)/\tau)}$
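As a companion to the formula, here is a minimal NumPy sketch (illustrative names, not from the slides) of one Gumbel-Softmax sample; at a small temperature τ the output is nearly one-hot, while at a large τ it approaches a uniform vector.

```python
import numpy as np

def gumbel_softmax_sample(log_alpha, tau, rng):
    """y_k = exp((log alpha_k + g_k)/tau) / sum_i exp((log alpha_i + g_i)/tau)"""
    u = rng.uniform(size=log_alpha.shape)
    g = -np.log(-np.log(u))                    # standard Gumbel noise
    z = (log_alpha + g) / tau
    z -= z.max()                               # softmax is shift-invariant; this avoids overflow
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
log_alpha = np.log(np.array([1.0, 2.0, 7.0]))
print(gumbel_softmax_sample(log_alpha, tau=0.1, rng=rng))   # nearly one-hot
print(gumbel_softmax_sample(log_alpha, tau=10.0, rng=rng))  # close to uniform
```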

Example problem

The GAN is made to approximate sequences generated by the context-free grammar below:

S → x | S + S | S − S | S ∗ S | S / S

where x is the only terminal. Examples of valid strings: x + x + x, x + x, x.
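For concreteness, a toy data generator along these lines might look like the sketch below (an assumption-laden illustration: the operator set follows the grammar as reconstructed above, and the depth cap that keeps strings short is not from the slides).

```python
import random

def sample_expression(rng, max_depth=3):
    """Sample a string from S -> x | S + S | S - S | S * S | S / S."""
    if max_depth == 0 or rng.random() < 0.5:
        return "x"                                    # the only terminal
    op = rng.choice(["+", "-", "*", "/"])
    left = sample_expression(rng, max_depth - 1)
    right = sample_expression(rng, max_depth - 1)
    return f"{left} {op} {right}"

rng = random.Random(0)
print([sample_expression(rng) for _ in range(5)])     # e.g. ['x', 'x + x', 'x * x - x', ...]
```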

RNN for discrete sequences

The generative model is based on an LSTM RNN.

Figure: A classic LSTM RNN model during the prediction phase

The LSTM RNN is trained to predict a hidden-state vector h at every time step; the softmax operator is then applied to h, which gives us a distribution over all possible generated characters (in our case: x, +, −, ∗, /, and the space used for padding). The LSTM model is trained by matching this softmax distribution to a one-hot encoding of the input data via Maximum Likelihood Estimation (MLE). We need to build a generative model for discrete sequences, which is accomplished by sampling through the LSTM using the GAN approach.
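To make the MLE training step concrete, here is a minimal PyTorch sketch (illustrative only: the CharLSTM class, vocabulary ordering, and hyperparameters are assumptions, not the authors' code). The linear layer produces the logit vector h, softmax over h is the character distribution, and cross-entropy against the next character's one-hot encoding is the maximum-likelihood objective.

```python
import torch
import torch.nn as nn

VOCAB = ["x", "+", "-", "*", "/", " "]               # assumed character set; space is the padding symbol

class CharLSTM(nn.Module):
    def __init__(self, vocab_size=len(VOCAB), hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.out = nn.Linear(hidden, vocab_size)     # logits h; softmax(h) = distribution over characters

    def forward(self, x, state=None):
        out, state = self.lstm(self.embed(x), state)
        return self.out(out), state

model = CharLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()                      # cross-entropy to the one-hot next character = MLE

# One toy training step on the string "x + x" (indices into VOCAB), predicting each next character.
seq = torch.tensor([[0, 5, 1, 5, 0]])
logits, _ = model(seq[:, :-1])
loss = loss_fn(logits.reshape(-1, len(VOCAB)), seq[:, 1:].reshape(-1))
loss.backward()
optimizer.step()
print(loss.item())
```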

In this case, both the generator G and the discriminator D are LSTMs, with parameters Θ and Φ respectively. The generator network G takes as input a sample pair which effectively replaces the initial hidden and memory cell states.

Figure: Generator network for discrete sequences

Figure: The GAN network

Our aim while training the GAN is to minimize differentiable loss functions for G and D in order to update Θ and Φ. However, as noted above, sampling points from the categorical distribution given by G's LSTM is not differentiable. We can instead use the Gumbel-Softmax distribution to sample from the LSTM and optimize Θ and Φ using back-propagation. That is, instead of

$y = \text{onehot}(\arg\max_i(h_i + g_i))$ (1)

we use

$y = \text{softmax}((h + g)/\tau)$ (2)

When $\tau \to 0$, (2) yields the same distribution as the samples generated by (1); when $\tau \to \infty$, the samples are always drawn from a uniform probability vector.
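The swap from (1) to (2) is a one-line change in code. The sketch below (a hedged illustration, not the authors' implementation) uses PyTorch's built-in torch.nn.functional.gumbel_softmax, which computes exactly softmax((logits + g)/τ); the toy score stands in for a discriminator output and shows that gradients now reach the generator's logits.

```python
import torch
import torch.nn.functional as F

h = torch.randn(1, 6, requires_grad=True)              # generator logits for one time step
tau = 1.0

y_hard = F.one_hot(h.argmax(dim=-1), num_classes=6)    # eq. (1): one-hot arg max, not differentiable
y_soft = F.gumbel_softmax(h, tau=tau)                   # eq. (2): softmax((h + g)/tau), differentiable

# Any differentiable function of y_soft (e.g. a discriminator score) backpropagates to h.
score = (y_soft * torch.arange(6.0)).sum()
score.backward()
print(y_hard, y_soft)
print(h.grad)                                            # non-None: gradients flow back to the generator
```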

Algorithm

Experimentation

5000 samples with a maximum length of 12 characters were generated from the CFG defined previously (all sequences with fewer than 12 characters were padded with spaces). G and D were trained for 20,000 mini-batch iterations, with the temperature coefficient linearly annealed from τ = 5 to τ = 1.

Results

Thank you