AdaGAN: Boosting Generative Models
Ilya Tolstikhin (ilya@tuebingen.mpg.de)
Joint work with Gelly (2), Bousquet (2), Simon-Gabriel (1), Schölkopf (1)
(1) MPI for Intelligent Systems, (2) Google Brain
[Image slides. Credits: (Radford et al., 2015); (Nguyen et al., 2016); (YouTube: Yota Ishida)]
Motivation
Generative Adversarial Networks, Variational Autoencoders, and other unsupervised representation learning / generative modelling techniques ROCK!
Dozens of papers are being submitted, discussed, published, ...
Among the papers submitted to ICLR 2017, 15 have titles matching "Generative Adversarial*".
NIPS 2016 Tutorial on GANs. The GAN Workshop at NIPS 2016 was the second-largest.
That's all very exciting. However:
GANs are very hard to train.
The field is, apparently, in its trial-and-error phase. People try all kinds of heuristic hacks. Sometimes they work, sometimes not...
Very few papers manage to go beyond heuristics and provide a deeper conceptual understanding of what's going on.
AdaGAN: adaptive mixtures of GANs
Contents
1. Generative Adversarial Networks and f-divergences. GANs are trying to minimize a specific f-divergence $\min_{P_{model} \in \mathcal{P}} D_f(P_{data} \,\|\, P_{model})$. At least, this is one way to explain what they are doing.
2. Minimizing f-divergences with mixtures of distributions. We study theoretical properties of the following problem: $\min_{\alpha_T,\, P_T \in \mathcal{P}} D_f\big(P_{data} \,\|\, \sum_{t=1}^{T} \alpha_t P_t\big)$ and show that it can be reduced to $\min_{P_T \in \mathcal{P}} D_f(P_{data} \,\|\, P_T)$. (*)
3. Back to GANs. Now we use GANs to train $P_T$ in (*), which results in AdaGAN.
f-GAN: Adversarial training for f-divergence minimization
For any probability distributions P and Q, the f-divergence is $D_f(Q \| P) := \int f\left(\frac{dQ}{dP}(x)\right) dP(x)$, where $f : (0, \infty) \to \mathbb{R}$ is convex and satisfies $f(1) = 0$.
Properties of f-divergences:
1. $D_f(P \| Q) \ge 0$ and $D_f(P \| Q) = 0$ when $P = Q$;
2. $D_f(P \| Q) = D_{f^\circ}(Q \| P)$ for $f^\circ(x) := x f(1/x)$;
3. $D_f(P \| Q)$ is symmetric when $f(x) = x f(1/x)$ for all x;
4. $D_f(P \| Q) = D_g(P \| Q)$ for any $g(x) = f(x) + C(x - 1)$;
5. Taking $f_0(x) := f(x) - (x - 1) f'(1)$ is a good choice. It is nonnegative, nonincreasing on $(0, 1]$, and nondecreasing on $[1, \infty)$.
f-GAN: Adversarial training for f-divergence minimization
For any probability distributions P and Q, the f-divergence is $D_f(Q \| P) := \int f\left(\frac{dQ}{dP}(x)\right) dP(x)$, where $f : (0, \infty) \to \mathbb{R}$ is convex and satisfies $f(1) = 0$.
Examples of f-divergences:
1. Kullback-Leibler divergence for $f(x) = x \log x$;
2. Reverse Kullback-Leibler divergence for $f(x) = -\log x$;
3. Total Variation distance for $f(x) = |x - 1|$;
4. Jensen-Shannon divergence for $f(x) = -(x + 1) \log\frac{x + 1}{2} + x \log x$: $JS(P, Q) = KL\left(P \,\|\, \frac{P + Q}{2}\right) + KL\left(Q \,\|\, \frac{P + Q}{2}\right)$.
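The examples above can be checked numerically on small discrete distributions. The following is a minimal sketch, not part of the original slides: the helper names `f_divergence` and `kl` and the two toy distributions are assumptions chosen for illustration, and both distributions are taken to have full support so that the density ratio is well defined.

```python
import numpy as np

def f_divergence(q, p, f):
    """D_f(Q||P) = sum_x f(q(x)/p(x)) * p(x), for discrete q, p with full support."""
    q, p = np.asarray(q, dtype=float), np.asarray(p, dtype=float)
    return float(np.sum(f(q / p) * p))

def kl(a, b):
    """Direct KL(A||B) = sum_x a(x) log(a(x)/b(x)), used only as a cross-check."""
    return float(np.sum(a * np.log(a / b)))

# Generators f taken from the list of examples.
f_kl = lambda x: x * np.log(x)
f_rkl = lambda x: -np.log(x)
f_tv = lambda x: np.abs(x - 1.0)
f_js = lambda x: -(x + 1.0) * np.log((x + 1.0) / 2.0) + x * np.log(x)

# Toy distributions (illustrative assumption).
q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.3, 0.5])
m = 0.5 * (p + q)

print(f_divergence(q, p, f_kl), kl(q, p))                # KL(Q||P)
print(f_divergence(q, p, f_rkl), kl(p, q))               # KL(P||Q)
print(f_divergence(q, p, f_tv), np.sum(np.abs(q - p)))   # sum_x |q(x) - p(x)|
print(f_divergence(q, p, f_js), kl(p, m) + kl(q, m))     # JS without the 1/2 factors
```

Each printed pair should agree up to floating-point error, which confirms that the generator and the closed-form divergence describe the same quantity under the normalization used on this slide.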
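f-GAN: Adversarial training for f-divergence minimization
Kullback-Leibler minimization: assume $X_1, \dots, X_N \sim Q$.
$\min_P KL(Q \| P) \Leftrightarrow \min_P \int \log\left(\frac{dQ}{dP}(x)\right) dQ(x) \Leftrightarrow \min_P -\int \log(dP(x))\, dQ(x) \Leftrightarrow \max_P \int \log(dP(x))\, dQ(x) \approx \max_P \frac{1}{N} \sum_{i=1}^{N} \log(dP(X_i))$.
Reverse Kullback-Leibler minimization:
$\min_P KL(P \| Q) \Leftrightarrow \min_P \int \log\left(\frac{dP}{dQ}(x)\right) dP(x) \Leftrightarrow \min_P -H(P) - \int \log(dQ(x))\, dP(x) \Leftrightarrow \max_P H(P) + \int \log(dQ(x))\, dP(x)$.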
f-GAN: Adversarial training for f-divergence minimization
Now we will follow (Nowozin et al., NIPS 2016). For any probability distributions P and Q over $\mathcal{X}$, the f-divergence is $D_f(Q \| P) := \int f\left(\frac{dQ}{dP}(x)\right) dP(x)$.
Variational representation: a very neat result, saying that $D_f(Q \| P) = \sup_{g : \mathcal{X} \to \mathrm{dom}(f^*)} \mathbb{E}_{X \sim Q}[g(X)] - \mathbb{E}_{Y \sim P}[f^*(g(Y))]$, where $f^*(x) := \sup_u \{x u - f(u)\}$ is the convex conjugate of f.
$D_f$ can be computed by solving an optimization problem!
The f-divergence minimization can now be written as: $\inf_{P \in \mathcal{P}} D_f(Q \| P) \Leftrightarrow \inf_{P \in \mathcal{P}} \sup_{g : \dots} \big[\mathbb{E}_{X \sim Q}[g(X)] - \mathbb{E}_{Y \sim P}[f^*(g(Y))]\big]$.
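The variational representation can be sanity-checked on a discrete example. The sketch below is an illustration under stated assumptions, not the authors' code: it uses the KL generator $f(u) = u \log u$, whose conjugate is $f^*(t) = e^{t-1}$, and the standard fact that the supremum is attained at $g = f'(dQ/dP) = 1 + \log(dQ/dP)$. The function name `variational_bound` and the toy distributions are hypothetical.

```python
import numpy as np

# Toy discrete distributions with full support (illustrative assumption).
q = np.array([0.5, 0.3, 0.2])
p = np.array([0.2, 0.3, 0.5])

f = lambda u: u * np.log(u)          # KL generator
f_conj = lambda t: np.exp(t - 1.0)   # its convex conjugate f*(t) = sup_u {t*u - f(u)}

def variational_bound(g):
    """E_{X~Q}[g(X)] - E_{Y~P}[f*(g(Y))] for a witness g given as a vector of values."""
    return float(np.sum(q * g) - np.sum(p * f_conj(g)))

d_f = float(np.sum(f(q / p) * p))    # D_f(Q||P) from the definition
g_opt = 1.0 + np.log(q / p)          # optimal witness f'(dQ/dP)

# Any other witness should give a value no larger than D_f(Q||P).
rng = np.random.default_rng(0)
perturbed = [variational_bound(g_opt + 0.5 * rng.standard_normal(3)) for _ in range(100)]

print(d_f, variational_bound(g_opt), max(perturbed))
```

The first two printed numbers should coincide, and the third should not exceed them, since the objective is concave in g.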
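f-GAN: Adversarial training for f-divergence minimization
The f-divergence minimization can now be written as: $\inf_{P \in \mathcal{P}} D_f(Q \| P) \Leftrightarrow \inf_{P \in \mathcal{P}} \sup_{g : \mathcal{X} \to \mathrm{dom}(f^*)} \big[\mathbb{E}_{X \sim Q}[g(X)] - \mathbb{E}_{Y \sim P}[f^*(g(Y))]\big]$.
1. Take P to be the distribution of $G_\theta(Z)$, where $Z \sim P_Z$ is a latent noise and $G_\theta$ is a deep net;
2. JS divergence corresponds to $f(x) = -(x + 1) \log\frac{x + 1}{2} + x \log x$ and $f^*(t) = -\log(2 - e^t)$, and thus the domain of $f^*$ is $(-\infty, \log 2)$;
3. Take $g = g_f \circ D_\omega$, where $g_f(v) = \log 2 - \log(1 + e^{-v})$ and $D_\omega$ is a deep net (check that $g(x) < \log 2$ for all x).
Then up to an additive $2 \log 2$ term we get: $\inf_{G_\theta} \sup_{D_\omega} \mathbb{E}_{X \sim Q}\left[\log \frac{1}{1 + e^{-D_\omega(X)}}\right] + \mathbb{E}_{Z \sim P_Z}\left[\log\left(1 - \frac{1}{1 + e^{-D_\omega(G_\theta(Z))}}\right)\right]$.
Compare to the original GAN objective (Goodfellow et al., NIPS 2014): $\inf_{G_\theta} \sup_{D_\omega} \mathbb{E}_{X \sim P_d}[\log D_\omega(X)] + \mathbb{E}_{Z \sim P_Z}[\log(1 - D_\omega(G_\theta(Z)))]$.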
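f-GAN: Adversarial training for f-divergence minimization
Then we get the original GAN objective (Goodfellow et al., NIPS 2014): $\inf_{G_\theta} \sup_{D_\omega} \mathbb{E}_{X \sim P_d}[\log D_\omega(X)] + \mathbb{E}_{Z \sim P_Z}[\log(1 - D_\omega(G_\theta(Z)))]$. (*)
Let's denote by $P_\theta$ the distribution of $G_\theta(Z)$ when $Z \sim P_Z$. For any fixed generator $G_\theta$, the theoretically optimal D is: $D^*_\theta(x) = \frac{dP_d(x)}{dP_d(x) + dP_\theta(x)}$.
If we insert this back into (*) we get $\inf_{G_\theta} 2 \cdot JS(P_d, P_\theta) - \log(4)$, where JS here carries the usual 1/2 factors in front of both KL terms.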
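The optimal-discriminator identity is easy to verify on a discrete toy problem. This is a minimal sketch with assumed toy distributions and a hypothetical helper `gan_value`; JS is computed with the standard 1/2 factors, which is the normalization under which the value $2 \cdot JS - \log 4$ holds.

```python
import numpy as np

# Toy data and model distributions on three points (illustrative assumption).
p_d = np.array([0.5, 0.3, 0.2])
p_theta = np.array([0.2, 0.3, 0.5])

def gan_value(d):
    """E_{P_d}[log D(X)] + E_{P_theta}[log(1 - D(Y))] for a discriminator given pointwise."""
    return float(np.sum(p_d * np.log(d)) + np.sum(p_theta * np.log(1.0 - d)))

d_star = p_d / (p_d + p_theta)       # theoretically optimal discriminator

# Jensen-Shannon divergence with the standard 1/2 factors.
m = 0.5 * (p_d + p_theta)
js = float(0.5 * np.sum(p_d * np.log(p_d / m)) + 0.5 * np.sum(p_theta * np.log(p_theta / m)))

print(gan_value(d_star), 2.0 * js - np.log(4.0))
print(gan_value(np.full(3, 0.5)))    # an uninformative discriminator gives -log 4
```

The two numbers on the first line should match, and the uninformative discriminator should score lower whenever the two distributions differ.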
f-GAN: Adversarial training for f-divergence minimization
The f-divergence minimization can now be written as: $\inf_{P \in \mathcal{P}} D_f(Q \| P) \Leftrightarrow \inf_{P \in \mathcal{P}} \sup_{g : \mathcal{X} \to \mathrm{dom}(f^*)} \big[\mathbb{E}_{X \sim Q}[g(X)] - \mathbb{E}_{Y \sim P}[f^*(g(Y))]\big]$.
This is GREAT, because:
1. We can estimate $\mathbb{E}_{X \sim Q}[g(X)]$ using an i.i.d. sample (the data) $X_1, \dots, X_N \sim Q$: $\mathbb{E}_{X \sim Q}[g(X)] \approx \frac{1}{N} \sum_{i=1}^{N} g(X_i)$;
2. Often we don't know the analytical form of P, but can only sample from P. No problem! Sample i.i.d. $Y_1, \dots, Y_M \sim P$ and write: $\mathbb{E}_{Y \sim P}[f^*(g(Y))] \approx \frac{1}{M} \sum_{j=1}^{M} f^*(g(Y_j))$.
This results in implicit generative modelling: using complex model distributions, which are analytically intractable, but easy to sample from.
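The two sample averages can be illustrated with a case where the answer is known in closed form. The sketch below is an assumption-laden toy, not the GAN setting itself: it takes Q = N(1, 1), P = N(0, 1) and the KL generator, for which the optimal witness is $g(t) = 1 + \log\frac{dQ}{dP}(t) = t + 0.5$, the conjugate is $f^*(t) = e^{t-1}$, and the true divergence is $KL(Q \| P) = 0.5$.

```python
import numpy as np

rng = np.random.default_rng(0)
N = M = 200_000
x = rng.normal(1.0, 1.0, size=N)     # "data" sample X_i ~ Q = N(1, 1)
y = rng.normal(0.0, 1.0, size=M)     # "model" sample Y_j ~ P = N(0, 1)

g = lambda t: t + 0.5                # optimal witness for KL: 1 + log(dQ/dP)(t)
f_conj = lambda t: np.exp(t - 1.0)   # conjugate of f(u) = u log u

# Replace both expectations by sample averages, as in items 1 and 2.
estimate = float(np.mean(g(x)) - np.mean(f_conj(g(y))))
print(estimate)                      # should be close to KL(Q||P) = 0.5
```

Only samples from P are used here; its density never appears, which is the point of the implicit-modelling remark.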
f-GAN: Adversarial training for f-divergence minimization
Note that there is still a difference between doing GAN: $\inf_{G_\theta} \sup_{D_\omega} \mathbb{E}_{X \sim P_d}[\log D_\omega(X)] + \mathbb{E}_{Z \sim P_Z}[\log(1 - D_\omega(G_\theta(Z)))]$ and minimizing $JS(P_{data}, P_{model})$: $\inf_{P_{model}} \sup_{g : \dots} \big[\mathbb{E}_{X \sim P_{data}}[g(X)] - \mathbb{E}_{Y \sim P_{model}}[f^*_{JS}(g(Y))]\big]$, because:
1. We replaced expectations with sample averages. Uniform laws of large numbers may not apply, as our function classes are huge;
2. Instead of taking the supremum over all possible witness functions g, we optimize over a restricted class of DNNs (which is still huge);
3. In practice we never optimize $g / D_\omega$ to the end. This leads to vanishing gradients and other bad things.
Minimizing f-divergences with mixtures of distributions
Assumption: Let's assume we know more or less how to solve $\min_{Q \in \mathcal{G}} D_f(Q \| P)$ for any f and any P, given we have an i.i.d. sample from P.
New problem: Now we want to approximate an unknown $P_d$ in $D_f$ using a mixture model of the form $P^T_{model} := \sum_{i=1}^{T} \alpha_i P_i$.
Greedy strategy:
1. Train $P_1$ to minimize $D_f(P_1 \| P_d)$, set $\alpha_1 = 1$, and get $P^1_{model} = P_1$;
2. Assume we trained $P_1, \dots, P_t$ with weights $\alpha_1, \dots, \alpha_t$. We are going to train one more component Q, choose a weight $\beta$ and update $P^{t+1}_{model} = \sum_{i=1}^{t} (1 - \beta) \alpha_i P_i + \beta Q$.
Minimizing f-divergences with mixtures of distributions
1. Train $P_1$ to minimize $D_f(P_1 \| P_d)$, set $\alpha_1 = 1$, and get $P^1_{model} = P_1$;
2. Assume we trained $P_1, \dots, P_t$ with weights $\alpha_1, \dots, \alpha_t$. We are going to train one more component Q, choose a weight $\beta$ and update $P^{t+1}_{model} = \sum_{i=1}^{t} (1 - \beta) \alpha_i P_i + \beta Q$.
The goal: we want to choose $\beta \in [0, 1]$ and Q greedily, so as to minimize $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d)$, where we denoted $P_g := P^t_{model}$.
A modest goal: find $\beta \in [0, 1]$ and Q such that, for some $c < 1$: $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d) \le c \cdot D_f(P_g \| P_d)$.
Inefficient in practice: as $t \to \infty$ we expect $\beta \to 0$. Only a $\beta$ fraction of points gets sampled from Q, so it gets harder and harder to tune Q.
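The weight bookkeeping in step 2 is simple enough to sketch directly. This is an illustrative helper (the name `add_component` is hypothetical), and the schedule $\beta_t = 1/t$ is only one of the heuristics mentioned later in the talk; with that schedule the mixture weights stay uniform.

```python
import numpy as np

def add_component(alphas, beta):
    """Rescale existing mixture weights by (1 - beta) and append beta for the new component."""
    alphas = np.asarray(alphas, dtype=float)
    return np.append((1.0 - beta) * alphas, beta)

alphas = np.array([1.0])                 # step 1: a single component with alpha_1 = 1
for t in range(2, 6):                    # step 2, repeated for components 2..5
    alphas = add_component(alphas, beta=1.0 / t)

print(alphas)                            # uniform weights over the five components

# Sampling from the mixture: first pick a component index, then sample from that component.
rng = np.random.default_rng(0)
component = rng.choice(len(alphas), p=alphas)
```

Sampling stays cheap because only one component generator has to be run per draw.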
Minimizing f-divergences with mixtures of distributions
$D_f((1 - \beta) P_g + \beta Q \,\|\, P_d)$ (*)
Inefficient in practice: as $t \to \infty$ we expect $\beta \to 0$. Only a $\beta$ fraction of points gets sampled from Q, so it gets harder and harder to tune Q.
Workaround: instead optimize an upper bound on (*), which contains a term $D_f(Q, Q_0)$ for some $Q_0$:
Lemma. Given $P_d$, $P_g$, and $\beta \in [0, 1]$, for any Q and any R with $\beta\, dR \le dP_d$: $(*) \le \beta D_f(Q \| R) + (1 - \beta) D_f\left(P_g \,\Big\|\, \frac{P_d - \beta R}{1 - \beta}\right)$.
If furthermore $D_f$ is a Hilbertian metric, then, for any R, we have $\sqrt{(*)} \le \sqrt{\beta D_f(Q \| R)} + \sqrt{D_f((1 - \beta) P_g + \beta R \,\|\, P_d)}$.
JS-divergence, TV, Hellinger, and others are Hilbertian metrics.
Minimizing f-divergences with mixtures of distributions
$D_f((1 - \beta) P_g + \beta Q \,\|\, P_d)$ (*)
Lemma. Given $P_d$, $P_g$, and $\beta \in [0, 1]$, for any Q and any R with $\beta\, dR \le dP_d$: $(*) \le \beta D_f(Q \| R) + (1 - \beta) D_f\left(P_g \,\Big\|\, \frac{P_d - \beta R}{1 - \beta}\right)$.
If furthermore $D_f$ is a Hilbertian metric, then, for any R, we have $(*) \le \left(\sqrt{\beta D_f(Q \| R)} + \sqrt{D_f((1 - \beta) P_g + \beta R \,\|\, P_d)}\right)^2$.
Recipe: choose $R^*$ to minimize the right-most terms and then minimize $D_f(Q \| R^*)$, which we can do given we have an i.i.d. sample from $R^*$.
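Minimizing f-divergences with mixtures of distributions
First surrogate upper bound
Lemma. Given $P_d$, $P_g$, and $\beta \in [0, 1]$, for any Q and R, if $D_f$ is Hilbertian: $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d) \le \left(\sqrt{\beta D_f(Q \| R)} + \sqrt{D_f((1 - \beta) P_g + \beta R \,\|\, P_d)}\right)^2$.
The optimal R is given in the next result:
Theorem. Given $P_d$, $P_g$, and $\beta \in (0, 1]$, the solution to $\min_{R \in \mathcal{P}} D_f((1 - \beta) P_g + \beta R \,\|\, P_d)$, where $\mathcal{P}$ is the class of all probability distributions, has the density $dR^*_\beta(x) = \frac{1}{\beta}\left(\lambda^* dP_d(x) - (1 - \beta) dP_g(x)\right)_+$ for some unique $\lambda^*$ satisfying $\int dR^*_\beta = 1$.
Also, $\lambda^* = 1$ if and only if $P_d((1 - \beta) dP_g > dP_d) = 0$, which is equivalent to $\beta\, dR^*_\beta = dP_d - (1 - \beta) dP_g$.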
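Minimizing f-divergences with mixtures of distributions
Second surrogate upper bound
Lemma. Given $P_d$, $P_g$, and $\beta \in [0, 1]$, for any Q and any R with $\beta\, dR \le dP_d$: $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d) \le \beta D_f(Q \| R) + (1 - \beta) D_f\left(P_g \,\Big\|\, \frac{P_d - \beta R}{1 - \beta}\right)$.
The optimal R is given in the next result:
Theorem. Given $P_d$, $P_g$, and $\beta \in (0, 1]$, assume $P_d(dP_g = 0) < \beta$. Then the solution to $\min_{R :\, \beta dR \le dP_d} D_f\left(P_g \,\Big\|\, \frac{P_d - \beta R}{1 - \beta}\right)$ is given by the density $dR^\dagger_\beta(x) = \frac{1}{\beta}\left(dP_d(x) - \lambda^\dagger (1 - \beta) dP_g(x)\right)_+$ for a unique $\lambda^\dagger \ge 1$ satisfying $\int dR^\dagger_\beta = 1$.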
Minimizing f-divergences with mixtures of distributions
Two different optimal R:
$dR^*_\beta(x) = \frac{1}{\beta}\left(\lambda^* dP_d(x) - (1 - \beta) dP_g(x)\right)_+$
$dR^\dagger_\beta(x) = \frac{1}{\beta}\left(dP_d(x) - \lambda^\dagger (1 - \beta) dP_g(x)\right)_+$
They look very symmetric, but no obvious connections were found...
Neither depends on f!
$\lambda^*$ and $\lambda^\dagger$ are implicitly defined via fixed-point equations.
$dR^*_\beta(x)$ is also a solution to the original problem, that is, of $\min_Q D_f((1 - \beta) P_g + \beta Q \,\|\, P_d)$.
We see that setting $\beta = 1$ is the best possible choice in theory. So we'll have to use various heuristics, including uniform, constant, ...
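For discrete distributions both densities can be computed by a one-dimensional search over the normalizing constant, since the total mass is monotone in it. The following is a minimal sketch under assumptions: the function names `bisect`, `r_star`, `r_dagger` and the toy distributions are illustrative, plain bisection stands in for whatever solver is used in practice, and the search interval for $\lambda^*$ relies on the fact that the total mass is 0 at $\lambda = 0$ and at least 1 at $\lambda = 1$.

```python
import numpy as np

def bisect(total_mass, lo, hi, increasing, iters=200):
    """Find lam in [lo, hi] with total_mass(lam) = 1, for a monotone total_mass."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if (total_mass(mid) < 1.0) == increasing:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

def r_star(p_d, p_g, beta):
    """dR* = (lam * p_d - (1 - beta) * p_g)_+ / beta, with lam in (0, 1]."""
    dens = lambda lam: np.maximum(lam * p_d - (1.0 - beta) * p_g, 0.0) / beta
    lam = bisect(lambda l: dens(l).sum(), 0.0, 1.0, increasing=True)
    return dens(lam), lam

def r_dagger(p_d, p_g, beta):
    """dR-dagger = (p_d - lam * (1 - beta) * p_g)_+ / beta, with lam >= 1."""
    dens = lambda lam: np.maximum(p_d - lam * (1.0 - beta) * p_g, 0.0) / beta
    hi = 2.0
    while dens(hi).sum() > 1.0:          # terminates when P_d(p_g = 0) < beta
        hi *= 2.0
    lam = bisect(lambda l: dens(l).sum(), 1.0, hi, increasing=False)
    return dens(lam), lam

# Toy example: the current model over-weights the first point.
p_d = np.array([0.25, 0.25, 0.25, 0.25])
p_g = np.array([0.70, 0.15, 0.10, 0.05])
rs, lam_s = r_star(p_d, p_g, beta=0.5)
rd, lam_d = r_dagger(p_d, p_g, beta=0.5)
print(rs, lam_s)
print(rd, lam_d)
```

In this toy case both solutions put zero mass on the point the current model already over-covers and differ in how they spread mass over the remaining points.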
Minimizing f-divergences with mixtures of distributions
Recipe: choose $R^*$ to minimize the right-most terms in the surrogate upper bounds and then minimize the first term $D_f(Q \| R^*)$.
We found the optimal $R^*$ distributions. Let's now assume we are super powerful and managed to find Q with $D_f(Q \| R^*) = 0$. In other words, we are going to update our model: $P_g \leftarrow (1 - \beta) P_g + \beta Q$ with Q either $R^*_\beta$ or $R^\dagger_\beta$.
How well are we doing in terms of the original objective $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d)$?
Theorem. We have: $D_f((1 - \beta) P_g + \beta R^*_\beta \,\|\, P_d) \le D_f((1 - \beta) P_g + \beta P_d \,\|\, P_d) \le (1 - \beta) D_f(P_g \| P_d)$ and $D_f((1 - \beta) P_g + \beta R^\dagger_\beta \,\|\, P_d) \le (1 - \beta) D_f(P_g \| P_d)$.
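The first chain of inequalities can be checked numerically on a small discrete example. The sketch below is only an illustration under assumptions: it uses the generator $f(x) = x \log x$, toy distributions with full support, and a hypothetical bisection helper `r_star` to obtain $R^*_\beta$; the bound for $R^\dagger_\beta$ is not checked here.

```python
import numpy as np

def d_f(a, p_d):
    """D_f(A || P_d) with f(x) = x log x, for strictly positive a and p_d."""
    return float(np.sum(a * np.log(a / p_d)))

def r_star(p_d, p_g, beta, iters=200):
    """dR* = (lam * p_d - (1 - beta) * p_g)_+ / beta, with lam found by bisection on [0, 1]."""
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        lam = 0.5 * (lo + hi)
        mass = np.maximum(lam * p_d - (1.0 - beta) * p_g, 0.0).sum() / beta
        lo, hi = (lam, hi) if mass < 1.0 else (lo, lam)
    return np.maximum(hi * p_d - (1.0 - beta) * p_g, 0.0) / beta

p_d = np.array([0.25, 0.25, 0.25, 0.25])
p_g = np.array([0.70, 0.15, 0.10, 0.05])
beta = 0.5

d_old = d_f(p_g, p_d)                                                # D_f(P_g || P_d)
d_mix = d_f((1 - beta) * p_g + beta * p_d, p_d)                      # mixing in P_d itself
d_new = d_f((1 - beta) * p_g + beta * r_star(p_d, p_g, beta), p_d)   # mixing in R*

print(d_new, d_mix, (1 - beta) * d_old)   # expected to be non-decreasing left to right
```

The gap between the first two numbers shows how much better the reweighted target $R^*_\beta$ is than simply mixing in the data distribution.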
Minimizing f-divergences with mixtures of distributions
Other highlights:
More careful convergence analysis. In particular we provide a necessary and sufficient condition for the iterative process to converge in a finite number of steps: there exists $M > 0$ such that $P_d((1 - \beta) dP_1 > M dP_d) = 0$. In other words, $P_1 \ll P_d$.
In practice we don't attain $D_f(Q \| R^*) = 0$. We prove that given $D_f(Q \| R^*_\beta) \le \gamma D_f(P_g \| P_d)$ or $D_f(Q \| R^\dagger_\beta) \le \gamma D_f(P_g \| P_d)$ we achieve $D_f((1 - \beta) P_g + \beta Q \,\|\, P_d) \le c \cdot D_f(P_g \| P_d)$ with $c = c(\gamma, \beta) < 1$. This is kind of similar to weak learnability in AdaBoost.
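AdaGAN: adaptive mixtures of GANs
We want to find Q so as to minimize $D_f(Q \| R^*_\beta)$, where $dR^*_\beta(x) = \frac{1}{\beta}\left(\lambda^* dP_d(x) - (1 - \beta) dP_g(x)\right)_+$. We would like to use a GAN for that. How do we sample from $R^*_\beta$?
Let's rewrite: $dR^*_\beta(x) = \frac{dP_d(x)}{\beta}\left(\lambda^* - (1 - \beta) \frac{dP_g}{dP_d}(x)\right)_+$.
Recall: $\mathbb{E}_{P_d}[\log D(X)] + \mathbb{E}_{P_g}[\log(1 - D(Y))]$ is maximized by $D^*(x) = \frac{dP_d(x)}{dP_d(x) + dP_g(x)} \Leftrightarrow \frac{dP_g}{dP_d}(x) = \frac{1 - D^*(x)}{D^*(x)} := h(x)$.
Each training sample ends up with a weight $w_i = \frac{p_i}{\beta}\left(\lambda^* - (1 - \beta) h(x_i)\right)_+ = \frac{1}{N\beta}\left(\lambda^* - (1 - \beta) h(x_i)\right)_+$ for uniform $p_i = 1/N$. (**)
Finally, we compute $\lambda^*$ using (**), by requiring that the weights sum to one.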
AdaGAN: adaptive mixtures of GANs
Each training sample ends up with a weight $w_i = \frac{p_i}{\beta}\left(\lambda^* - (1 - \beta) h(x_i)\right)_+ = \frac{1}{N\beta}\left(\lambda^* - (1 - \beta) \frac{1 - D(x_i)}{D(x_i)}\right)_+$.
Now we run a usual GAN with an importance-weighted sample.
$D(x_i) \approx 1$: our current model $P_g$ is missing this point. Big $w_i$.
$D(x_i) \approx 0$: $x_i$ is covered by $P_g$ nicely. Zero $w_i$.
The better our model gets, the more data points will be zeroed. This leads to an aggressive adaptive iterative procedure.
$\beta$ can also be chosen so as to force $\#\{w_i : w_i > 0\} \approx r \cdot N$.
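The reweighting step can be sketched in a few lines once discriminator outputs on the training points are available. This is a minimal illustration, not the reference implementation from the repository: the function name `adagan_weights`, the bisection search for $\lambda^*$, and the example discriminator outputs are all assumptions, and $D(x_i)$ is assumed to lie strictly above 0.

```python
import numpy as np

def adagan_weights(d, beta, iters=200):
    """Weights w_i = (lam - (1 - beta) * h_i)_+ / (N * beta), with lam chosen so that sum(w) = 1."""
    d = np.asarray(d, dtype=float)            # discriminator outputs D(x_i), assumed in (0, 1]
    n = d.size
    h = (1.0 - d) / d                         # estimate of dP_g/dP_d at x_i
    weights = lambda lam: np.maximum(lam - (1.0 - beta) * h, 0.0) / (n * beta)
    # The total weight is 0 at lam = 0 and at least 1 at the upper end of this interval.
    lo, hi = 0.0, beta + (1.0 - beta) * h.max()
    for _ in range(iters):
        lam = 0.5 * (lo + hi)
        if weights(lam).sum() < 1.0:
            lo = lam
        else:
            hi = lam
    return weights(hi), hi

# Hypothetical discriminator outputs: the first points look "real but missed by the model".
d = np.array([0.9, 0.8, 0.5, 0.2, 0.1])
w, lam = adagan_weights(d, beta=0.5)
print(w, lam)
```

Points with a high discriminator output receive the largest weights and points the model already covers are zeroed out, matching the two cases described above. Because D is only an estimate, the resulting `lam` need not stay below 1.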
AdaGAN: adaptive mixtures of GANs
Coverage metric C: mass of the true data covered by the model: $C := P_{data}(dP_{model}(X) > t)$ where $P_{model}(dP_{model} > t) = 0.95$.
Experimental setting: mixture of 5 two-dimensional Gaussians.
[Plot panels: Best of T GANs; Bagging; AdaGAN]
Blue points: median over 35 random runs; green error bars: 90% intervals. Not the confidence intervals for medians!
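The coverage metric can be estimated by Monte Carlo when the model density can be evaluated. The sketch below is a simplified one-dimensional illustration, not the experimental setup of the talk: it assumes the model density is known in closed form (in practice it would have to be estimated), and the helper names and the two-mode toy data are hypothetical.

```python
import numpy as np

def normal_pdf(x, mu):
    """Density of N(mu, 1)."""
    return np.exp(-0.5 * (x - mu) ** 2) / np.sqrt(2.0 * np.pi)

def mixture_pdf(x, means):
    """Density of an equal-weight mixture of unit-variance Gaussians."""
    return np.mean([normal_pdf(x, mu) for mu in means], axis=0)

def sample_mixture(rng, means, n):
    """Draw n points from the equal-weight mixture."""
    return rng.choice(means, size=n) + rng.standard_normal(n)

def coverage(data, model_samples, model_pdf, level=0.95):
    """C = P_data(p_model(X) > t), with t chosen so that P_model(p_model > t) is about `level`."""
    t = np.quantile(model_pdf(model_samples), 1.0 - level)
    return float(np.mean(model_pdf(data) > t))

rng = np.random.default_rng(0)
data = sample_mixture(rng, [-3.0, 3.0], 100_000)          # true data: two modes

# A model that misses one mode versus a model that matches the data.
c_missing = coverage(data, sample_mixture(rng, [-3.0], 100_000),
                     lambda x: mixture_pdf(x, [-3.0]))
c_full = coverage(data, sample_mixture(rng, [-3.0, 3.0], 100_000),
                  lambda x: mixture_pdf(x, [-3.0, 3.0]))
print(c_missing, c_full)
```

A model that drops one of two equally likely modes covers roughly half of the attainable 0.95, while a perfect model reaches about 0.95, which is why C is a convenient summary of missing modes.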
AdaGAN: adaptive mixtures of GANs
Coverage metric C: mass of the true data covered by the model: $C := P_{data}(dP_{model}(X) > t)$ where $P_{model}(dP_{model} > t) = 0.95$.
TopKLast r: $\beta_t$ is chosen to zero out $r \cdot N$ of the training points;
Boosted: AdaGAN with $\beta_t = 1/t$.
AdaGAN: adaptive mixtures of GANs
Coverage metric C: mass of the true data covered by the model: $C := P_{data}(dP_{model}(X) > t)$ where $P_{model}(dP_{model} > t) = 0.95$.
Experimental setting: mixture of 10 two-dimensional Gaussians.
AdaGAN: adaptive mixtures of GANs
Conclusions
Pros: It seems that AdaGAN is solving the missing modes problem; AdaGAN is compatible with any implementation of GAN and, more generally, with any other implicit generative modelling technique; Still easy to sample from; Easy to implement, with great freedom in the choice of heuristics; Theoretically grounded!
Cons: We lose the nice manifold structure in the latent space; Increased computational complexity; Individual GANs are still very unstable.
Thank you! AdaGAN: Boosting Generative Models. Tolstikhin, Gelly, Bousquet, Simon-Gabriel, Schoelkopf, NIPS 2017. https://github.com/tolstikhin/adagan
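Minimizing the optimal transport
Optimal transport for a cost function $c(x, y) : \mathcal{X} \times \mathcal{X} \to \mathbb{R}_+$ is $W_c(P_X, P_G) := \inf_{\Gamma \in \mathcal{P}(X \sim P_X,\, Y \sim P_G)} \mathbb{E}_{(X, Y) \sim \Gamma}[c(X, Y)]$.
If $P_G(Y | Z = z) = \delta_{G(z)}$ for all $z \in \mathcal{Z}$, where $G : \mathcal{Z} \to \mathcal{X}$, we have $W_c(P_X, P_G) = \inf_{Q :\, Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]$, where $Q_Z$ is the marginal distribution of Z when $X \sim P_X$, $Z \sim Q(Z|X)$.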
Relaxing the constraint
$W_c(P_X, P_G) = \inf_{Q :\, Q_Z = P_Z} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))]$
Penalized Optimal Transport replaces the constraint with a penalty: $POT(P_X, P_G) := \inf_Q \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot D(Q_Z, P_Z)$ and uses adversarial training in the Z space to approximate D.
For the 2-Wasserstein distance, $c(X, Y) = \|X - Y\|_2^2$, POT recovers Adversarial Auto-Encoders (Makhzani et al., 2016).
For the 1-Wasserstein distance, $c(X, Y) = \|X - Y\|_2$, POT and WGAN are solving the same problem from the primal and dual forms respectively.
Importantly, unlike VAE, POT does not force $Q(Z | X = x)$ to intersect for different x, which is known to lead to blurriness. (Bousquet et al., 2017)
Further details on GAN: the log trick
$\inf_{G_\theta} \sup_{T_\omega} \mathbb{E}_{X \sim P_d}[\log T_\omega(X)] + \mathbb{E}_{Z \sim P_Z}[\log(1 - T_\omega(G_\theta(Z)))]$
In practice (verify this!) $G_\theta$ is often trained to instead minimize $-\mathbb{E}_{Z \sim P_Z}[\log T_\omega(G_\theta(Z))]$ (the "log trick").
One possible explanation: for any $f_r$ and $P_G$ the optimal $T^*$ in the variational representation of $D_{f_r}(P_X \| P_G)$ has the form $T^*(x) := f_r'\left(\frac{p_X(x)}{p_G(x)}\right) \Leftrightarrow \frac{p_X(x)}{p_G(x)} = (f_r')^{-1}(T^*(x))$.
1. Approximate $T^* \approx T_\omega$ using adversarial training with $f_r$;
2. Express $\frac{p_X}{p_G}(x) = (f_r')^{-1}(T_\omega(x))$;
3. Approximate the desired f-divergence $D_f(P_X \| P_G)$ using $\int_{\mathcal{X}} f\left(\frac{p_X(x)}{p_G(x)}\right) p_G(x)\, dx = \int_{\mathcal{X}} f\left((f_r')^{-1}(T^*(x))\right) p_G(x)\, dx$.
Applying this argument with the JS $f_r$ and $f(x) = \log(1 + x^{-1})$ recovers the log trick. The corresponding f-divergence is much more mode-seeking than the reverse KL divergence.