Learning features to compare distributions. Arthur Gretton, Gatsby Computational Neuroscience Unit, University College London. NIPS 2016 Workshop on Adversarial Learning, Barcelona, Spain.
Goal of this talk. Have: two collections of samples, X and Y, from unknown distributions P and Q. Goal: learn distinguishing features that indicate how P and Q differ.
Divergences

[Figures: illustrations of divergence measures between P and Q, omitted]

Sriperumbudur, Fukumizu, G, Schoelkopf, Lanckriet (2012)
Overview

The maximum mean discrepancy:
- How to compute and interpret the MMD
- How to train the MMD
- Application to troubleshooting GANs

The ME test statistic:
- Informative, linear-time features for comparing distributions
- How to learn these features

TL;DR: variance matters.
The maximum mean discrepancy

Are P and Q different? [Figure: overlaid densities P(x) and Q(y)]
Maximum mean discrepancy (on sample)

Observe X = {x_1, ..., x_n} ~ P; observe Y = {y_1, ..., y_n} ~ Q.

Place a Gaussian kernel on each x_i and on each y_i.

Mean embeddings:
μ_P(v) = (1/n) Σ_{i=1}^n k(x_i, v)   (mean embedding of P)
μ_Q(v) = (1/n) Σ_{i=1}^n k(y_i, v)   (mean embedding of Q)

witness(v) = μ_P(v) − μ_Q(v)
Maximum mean discrepancy (on sample)

MMD² = ||witness||²
     = 1/(n(n−1)) Σ_{i≠j} k(x_i, x_j) + 1/(n(n−1)) Σ_{i≠j} k(y_i, y_j) − (2/n²) Σ_{i,j} k(x_i, y_j)
Dogs (= P) and fish (= Q) example revisited. Each entry of the kernel matrix is one of k(dog_i, dog_j), k(dog_i, fish_j), or k(fish_i, fish_j).

The maximum mean discrepancy:
MMD² = 1/(n(n−1)) Σ_{i≠j} k(dog_i, dog_j) + 1/(n(n−1)) Σ_{i≠j} k(fish_i, fish_j) − (2/n²) Σ_{i,j} k(dog_i, fish_j)
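The estimator above is straightforward to compute. A minimal NumPy sketch (function names are my own, assuming a Gaussian kernel):

```python
import numpy as np

def gaussian_kernel(A, B, sigma=1.0):
    # Pairwise Gaussian (RBF) kernel matrix between rows of A and rows of B.
    sq = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-sq / (2 * sigma**2))

def mmd2_unbiased(X, Y, sigma=1.0):
    # Unbiased MMD^2 estimate: diagonal terms are excluded from the
    # within-sample sums, matching the 1/(n(n-1)) normalisation above.
    n, m = len(X), len(Y)
    Kxx = gaussian_kernel(X, X, sigma)
    Kyy = gaussian_kernel(Y, Y, sigma)
    Kxy = gaussian_kernel(X, Y, sigma)
    term_x = (Kxx.sum() - np.trace(Kxx)) / (n * (n - 1))
    term_y = (Kyy.sum() - np.trace(Kyy)) / (m * (m - 1))
    return term_x + term_y - 2 * Kxy.mean()
```

On two samples from the same distribution the estimate hovers near zero; on well-separated samples it is clearly positive.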
Asymptotics of MMD

The MMD:
MMD² = 1/(n(n−1)) Σ_{i≠j} k(x_i, x_j) + 1/(n(n−1)) Σ_{i≠j} k(y_i, y_j) − (2/n²) Σ_{i,j} k(x_i, y_j)
...but how to choose the kernel?

Perspective from statistical hypothesis testing:
- When P = Q, MMD² is close to zero.
- When P ≠ Q, MMD² is far from zero.
- A threshold c_α on MMD² gives a false positive rate of α.
A statistical test

[Figure: densities of n·MMD² under P = Q and under P ≠ Q; the threshold c_α is the (1−α) quantile when P = Q, and the mass of the P ≠ Q density below c_α gives the false negative rate]

The best kernel gives the lowest false negative rate (= highest power)... but can you train for this?
Asymptotics of MMD

When P ≠ Q, the statistic is asymptotically normal:
( MMD² − MMD²(P,Q) ) / √V_n(P,Q)  →_D  N(0, 1),
where MMD(P,Q) is the population MMD and V_n(P,Q) = O(n⁻¹).

[Figure: empirical MMD distribution under H1 with Gaussian fit]
Asymptotics of MMD

When P = Q, the statistic has asymptotic distribution
n·MMD²  →  Σ_{l=1}^∞ λ_l ( z_l² − 2 ),
where λ_i ψ_i(x′) = ∫ k̃(x, x′) ψ_i(x) dP(x), k̃ is the centred kernel, and z_l ~ N(0, 2) i.i.d.

[Figure: empirical null density of n·MMD² with the χ² sum fit]
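Since this weighted χ² null distribution has no simple closed form, a common practical route (not specific to these slides) is to estimate the threshold ĉ_α by permutation: repeatedly re-split the pooled sample and recompute the statistic. A sketch, where `mmd2_fn` stands for any two-sample statistic such as the MMD² estimator above:

```python
import numpy as np

def permutation_threshold(X, Y, mmd2_fn, alpha=0.05, n_perms=200, seed=0):
    # Estimate the (1 - alpha) quantile of the statistic's null distribution
    # by randomly re-splitting the pooled sample, which simulates P = Q.
    rng = np.random.default_rng(seed)
    pooled = np.concatenate([X, Y])
    n = len(X)
    stats = []
    for _ in range(n_perms):
        perm = rng.permutation(len(pooled))
        stats.append(mmd2_fn(pooled[perm[:n]], pooled[perm[n:]]))
    return np.quantile(stats, 1 - alpha)
```

The test then rejects P = Q whenever the statistic on the original split exceeds the returned threshold.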
Optimizing test power

The power of our test (Pr_1 denotes probability under P ≠ Q):
Pr_1( n·MMD² > ĉ_α )  →  Φ( MMD²(P,Q)/√V_n(P,Q)  −  ĉ_α/(n·√V_n(P,Q)) ),
where Φ is the CDF of the standard normal distribution and ĉ_α is an estimate of the test threshold c_α.

The threshold term is O(n^{−1/2}), hence asymptotically negligible; the first term grows as O(n^{1/2}).

To maximize test power, maximize MMD²(P,Q)/√V_n(P,Q).
(Sutherland, Tung, Strathmann, De, Ramdas, Smola, G., in review for ICLR 2017)
Code: github.com/dougalsutherland/opt-mmd
Troubleshooting for generative adversarial networks

[Figures: MNIST samples, samples from a GAN, and the ARD map showing which pixels the optimized kernel weights]

Power for optimized ARD kernel: 1.00 at α = 0.01
Power for optimized RBF kernel: 0.57 at α = 0.01
Benchmarking generative adversarial networks

[Figure: benchmark results, omitted]
The ME statistic and test
Distinguishing Feature(s)

μ_P(v): mean embedding of P; μ_Q(v): mean embedding of Q.
witness(v) = μ_P(v) − μ_Q(v)
Distinguishing Feature(s)

Take the square of the witness, witness²(v) (only worry about amplitude).

New test statistic: witness² at a single location v. Linear time in the number n of samples... but how to choose the best feature v?

Best feature = the v that maximizes witness²(v)??
Distinguishing Feature(s)

[Figures: empirical witness²(v) at sample sizes n = 3, n = 50, and n = 500, and the population witness² function, each with the densities P(x) and Q(y) overlaid. At the population level, two candidate maximizers v appear: which v should we pick?]
Variance of witness function

Variance at v = variance of X at v + variance of Y at v.

ME statistic: λ̂_n(v) = n · witness²(v) / variance(v).

[Figures: witness²(v) together with variance_X(v), variance_Y(v), their sum, and the resulting λ̂_n(v)]

The best location is the v that maximizes λ̂_n(v). Performance improves with multiple locations {v_j}_{j=1}^J.
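The single-location statistic above is cheap to evaluate. A minimal sketch with a Gaussian kernel (function name mine; a small eps guards against division by zero far from both samples):

```python
import numpy as np

def me_statistic(X, Y, v, sigma=1.0, eps=1e-8):
    # Linear-time ME statistic at a single test location v:
    # lambda_n(v) = n * witness(v)^2 / (var_X(v) + var_Y(v)).
    def feat(A):
        # k(a, v) for each sample a: a single Gaussian feature at v.
        return np.exp(-np.sum((A - v) ** 2, axis=1) / (2 * sigma**2))
    fx, fy = feat(X), feat(Y)
    witness = fx.mean() - fy.mean()
    variance = fx.var() + fy.var() + eps
    n = len(X)
    return n * witness**2 / variance
```

In practice one scans (or gradient-ascends) over candidate locations v on a held-out split and keeps the maximizer; a location near where P and Q differ scores far higher than one far from both samples.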
Distinguishing Positive/Negative Emotions

35 females and 35 males (Lundqvist et al., 1998). Positive emotions: happy, neutral, surprised; negative emotions: afraid, angry, disgusted. Pixel features, 48 × 34 = 1632 dimensions. Sample size: 402.

[Figure: test power for + vs. −, comparing a random feature, the proposed learned feature, and MMD (quadratic time); the proposed test reaches power near 1. A second figure shows the learned feature.]

The proposed test achieves maximum test power in time O(n). Informative features: differences at the nose, and smile lines.

Code: https://github.com/wittawatj/interpretable-test
Final thoughts

Witness function approaches:
- Diversity of samples: the MMD test uses pairwise similarities between all samples; the ME test uses similarities to J reference features.
- Disjoint support of generator/data distributions: the witness function is smooth.

Other discriminator heuristics:
- Diversity of samples via the minibatch heuristic (add distances to neighbour samples as features), Salimans et al. (2016).
- Disjoint support treated by adding noise to blur images, Arjovsky and Bottou (2016), Huszar (2016).
Co-authors

Students and postdocs: Kacper Chwialkowski (at Voleon), Wittawat Jitkrittum, Heiko Strathmann, Dougal Sutherland.
Collaborators: Kenji Fukumizu, Krikamol Muandet, Bernhard Schoelkopf, Bharath Sriperumbudur, Zoltan Szabo.

Questions?
Testing against a probabilistic model
Statistical model criticism

MMD(P, Q) = sup_{||f|| ≤ 1} [ E_Q f − E_P f ]

[Figure: densities p(x), q(x) and the witness f*(x)]

f*(x) is the witness function. Can we compute the MMD with samples from Q and a model P? Problem: we usually can't compute E_P f in closed form.
Stein idea

To get rid of E_P f in sup_{||f|| ≤ 1} [ E_Q f − E_P f ], define the Stein operator
(T_p f)(x) = ∇f(x) + f(x) ∇ log p(x).
Then E_P [T_P f] = 0, subject to appropriate boundary conditions. (Oates, Girolami, Chopin, 2016)
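The identity E_P[T_P f] = 0 is easy to verify numerically. A small sketch (my own example) for p = N(0, 1), where ∇ log p(x) = −x, with the test function f(x) = sin(x):

```python
import numpy as np

def stein_op_gauss(f, fprime, x):
    # Stein operator for p = N(0, 1): (T_p f)(x) = f'(x) + f(x) * d/dx log p(x),
    # and d/dx log p(x) = -x for the standard normal.
    return fprime(x) + f(x) * (-x)

rng = np.random.default_rng(0)
x = rng.normal(size=100_000)
# Monte Carlo estimate of E_p[T_p f]; should vanish for smooth, bounded f.
est = stein_op_gauss(np.sin, np.cos, x).mean()
```

Here the identity reduces to Stein's lemma, E[x·sin(x)] = E[cos(x)] under N(0, 1), so the Monte Carlo estimate sits near zero.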
Maximum Stein Discrepancy

Stein operator: (T_p f)(x) = ∇f(x) + f(x) ∇ log p(x).

Maximum Stein Discrepancy (MSD):
MSD(p, q) = sup_{||g|| ≤ 1} [ E_q T_p g − E_p T_p g ] = sup_{||g|| ≤ 1} E_q T_p g.

[Figures: densities p(x), q(x) and the Stein witness g*(x)]
Maximum stein discrepancy

Closed-form expression for the MSD: given Z, Z′ ~ q, then
MSD²(p, q) = E_q h_p(Z, Z′),
where
h_p(x, y) = ∇ log p(x) · ∇ log p(y) k(x, y) + ∇ log p(y) · ∇_x k(x, y) + ∇ log p(x) · ∇_y k(x, y) + ∇_x · ∇_y k(x, y),
and k is the RKHS kernel. (Chwialkowski, Strathmann, G., 2016), (Liu, Lee, Jordan, 2016)

This only depends on the kernel and on ∇_x log p(x): we do not need to normalize p, or sample from it.
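A minimal one-dimensional sketch of this closed form (my own V-statistic implementation, Gaussian kernel): all four terms of h_p use only the samples from q and the model's score function ∇ log p, never the normalizer of p.

```python
import numpy as np

def ksd_vstat(samples, score, sigma=1.0):
    # V-statistic estimate of the squared Stein discrepancy for 1-d samples
    # from q, a Gaussian kernel k(x,y) = exp(-(x-y)^2 / (2 sigma^2)),
    # and score(x) = d/dx log p(x) of the model p.
    x = samples[:, None]
    y = samples[None, :]
    d = x - y
    k = np.exp(-d**2 / (2 * sigma**2))
    dkdx = -d / sigma**2 * k                      # d/dx k(x, y)
    dkdy = d / sigma**2 * k                       # d/dy k(x, y)
    dkdxdy = (1 / sigma**2 - d**2 / sigma**4) * k  # d^2/dxdy k(x, y)
    sx, sy = score(x), score(y)
    h = sx * sy * k + sy * dkdx + sx * dkdy + dkdxdy
    return h.mean()
```

For samples drawn from the model itself the estimate is near zero; for samples from a shifted distribution it is clearly positive, which is the basis of the goodness-of-fit test.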
Statistical model criticism

[Figure: normalised solar activity, 1600–2000]

Test the hypothesis that a Gaussian process model, learned from data, is a good fit for the test data (example from Lloyd and Ghahramani, 2015).
Code: https://github.com/karlnapf/kernel_goodness_of_fit
Statistical model criticism

[Figure: histogram of the test statistic V_n against the bootstrapped null distribution B_n]

Test the hypothesis that a Gaussian process model, learned from data, is a good fit for the test data.