Causal Modeling with Generative Neural Networks
Michele Sebag
TAO, CNRS INRIA LRI Université Paris-Sud
Joint work: D. Kalainathan, O. Goudet, I. Guyon, M. Hajaiej, A. Decelle, C. Furtlehner
https://arxiv.org/abs/1709.05321
Credit for slides: Yann LeCun
Leiden, Sept. 2017
Outline
  Motivation
  State of the art
  Causal Generative Neural Nets
  Naive ML approach to space weather (SW)
ML: discriminative or generative modelling
Given a training set, usually iid samples $\sim P(X, Y)$:
  $E = \{(x_i, y_i),\ x_i \in \mathbb{R}^d,\ i \in [[1, n]]\}$
Find
  Supervised learning: $\hat{h} : X \mapsto Y$ or $\hat{P}(Y \mid X)$
  Generative model: $\hat{P}(X, Y)$
Predictive modelling might be based on correlations:
  If umbrellas in the street, Then it rains
The big data promise: ML models are expected to support interventions
  health and nutrition
  education
  economics/management
  climate
Intervention (Pearl 2009)
  The intervention $do(X = x)$ forces variable $X$ to value $x$
  Direct cause $X_i \rightarrow X_j$:
    $P_{X_j \mid do(X_i = x,\ X_{\setminus ij} = c)} \neq P_{X_j \mid do(X_i = x',\ X_{\setminus ij} = c)}$
Example. C: Cancer, S: Smoking, G: Genetic factors
  $P(C \mid do\{S = 0, G = 0\}) \neq P(C \mid do\{S = 1, G = 0\})$
Correlations do not support interventions
Causal models are needed to support interventions
Why is this relevant to space weather?
  Causal models support understanding
  Causal models are more robust, e.g., to concept drift
Given observations drawn from $P(X)$, $P(Y \mid X)$,
find $\hat{P}(Y \mid X)$ that minimizes
  $\mathbb{E}_{x \sim P(X)} \left[ \arg\max_y \hat{P}(y \mid x) \neq \arg\max_y P(y \mid x) \right]$
But $P(X)$ in production might differ from $P(X)$ in training
Causal modelling, how
Historically, based on interventions. However, these are often
  impossible: climate
  unethical: making people smoke
  too expensive: e.g., in economics
Machine Learning alternatives
  Observational data
  Statistical tests
  Learned models
  Prior knowledge / Assumptions / Constraints
Functional Causal Models, a.k.a. Structural Equation Models
  $X_i = f_i(Pa(X_i), E_i)$
  $Pa(X_i)$: direct causes of $X_i$
  All unobserved influences: noise variables $E_i$
Example
  $X_1 = f_1(E_1)$
  $X_2 = f_2(X_1, E_2)$
  $X_3 = f_3(X_1, E_3)$
  $X_4 = f_4(E_4)$
  $X_5 = f_5(X_3, X_4, E_5)$
Tasks
  Finding the structure of the graph (no cycles)
  Finding the functions $(f_i)$
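A minimal sketch of how data could be drawn from the five-variable model above. The graph structure follows the slide; the concrete mechanisms (linear, quadratic) and the Gaussian noise are hypothetical choices made only for illustration.

```python
import random


def sample_fcm(n, seed=0):
    """Draw n samples from the five-variable FCM.

    Only the parent sets come from the slide; every functional form
    below is an illustrative assumption.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        # Independent noise variables E1..E5
        e1, e2, e3, e4, e5 = (rng.gauss(0.0, 1.0) for _ in range(5))
        x1 = e1                      # X1 = f1(E1)
        x2 = 2.0 * x1 + e2           # X2 = f2(X1, E2)
        x3 = x1 ** 2 + 0.5 * e3      # X3 = f3(X1, E3)
        x4 = e4                      # X4 = f4(E4)
        x5 = x3 - x4 + 0.1 * e5      # X5 = f5(X3, X4, E5)
        samples.append((x1, x2, x3, x4, x5))
    return samples


data = sample_fcm(1000)
print(data[0])
```

Variables are generated in a topological order of the graph, which is always possible because the graph has no cycles.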
Conducting a causal modelling study
Milestones
  Testing bivariate independence (statistical tests): find edges
  Conditional independence: prune the edges
  Full causal graph modelling: orient the edges
  [Figure: edge configurations over X, Y, Z]
Challenges
  Computational complexity: tractable approximation
  Conditional independence: data hungry tests
  Assuming causal sufficiency: can be relaxed
$X \perp Y$ independence
Categorical variables
  $P(X, Y) \overset{?}{=} P(X) \cdot P(Y)$
  Entropy: $H(X) = -\sum_x p(x) \log p(x)$
    $x$: value taken by $X$, $p(x)$ its frequency
  Mutual information: $M(X, Y) = H(X) + H(Y) - H(X, Y)$
  Others: $\chi^2$, G-test
Continuous variables
  t-test, z-test
  Hilbert-Schmidt Independence Criterion (HSIC), Gretton et al., 05
    $Cov(f, g) = \mathbb{E}_{x,y}[f(x) g(y)] - \mathbb{E}_x[f(x)] \, \mathbb{E}_y[g(y)]$
    Given $f : X \mapsto \mathbb{R}$ and $g : Y \mapsto \mathbb{R}$
    $Cov(f, g) = 0$ for all $f, g$ iff $X$ and $Y$ are independent
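The entropy and mutual information formulas for categorical variables can be sketched with empirical frequencies. The toy samples below are made up for illustration; the natural logarithm is used, so values are in nats.

```python
import math
from collections import Counter


def entropy(values):
    """Empirical entropy H = -sum_x p(x) log p(x), in nats."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def mutual_information(xs, ys):
    """M(X, Y) = H(X) + H(Y) - H(X, Y), estimated from paired samples."""
    return entropy(xs) + entropy(ys) - entropy(list(zip(xs, ys)))


# Toy data (hypothetical): Y is either a copy of X or unrelated to X.
xs = ["a", "a", "b", "b"] * 25
ys_copy = list(xs)
ys_indep = ["u", "v", "u", "v"] * 25

mi_copy = mutual_information(xs, ys_copy)
mi_indep = mutual_information(xs, ys_indep)
print(mi_copy, mi_indep)
```

When Y copies X the mutual information equals H(X); when the joint frequencies factorize it is zero up to rounding.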
An ML approach (Guyon et al., 2014-2015)
  $E = \{(A_i, B_i, \ell_i),\ \ell_i \in \{\rightarrow, \leftarrow, \perp\}\}$
Exploiting the distribution asymmetry (Hoyer et al. 09; Mooij et al. 2016)
True model, with noise $\epsilon$ independent of $X$:
  $Y = X + \epsilon$
Learn $Y = f(X)$, plot the residual $Y - f(X)$
Learn $X = g(Y)$, plot the residual $X - g(Y)$
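A minimal sketch of the residual check, under assumptions not stated on the slide: both regressions are linear least squares, the noise is uniform (with Gaussian noise and a linear mechanism the two directions cannot be told apart), and the dependence between input and residual is measured by a crude proxy, the absolute correlation of their squares, rather than by a proper independence test such as HSIC.

```python
import random


def mean(v):
    return sum(v) / len(v)


def pearson(u, v):
    """Pearson correlation coefficient of two equal-length sequences."""
    mu, mv = mean(u), mean(v)
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    var_u = sum((a - mu) ** 2 for a in u)
    var_v = sum((b - mv) ** 2 for b in v)
    return cov / (var_u * var_v) ** 0.5


def linear_residuals(inp, out):
    """Fit out = slope * inp + intercept by least squares; return residuals."""
    mi, mo = mean(inp), mean(out)
    slope = (sum((a - mi) * (b - mo) for a, b in zip(inp, out))
             / sum((a - mi) ** 2 for a in inp))
    return [b - (mo + slope * (a - mi)) for a, b in zip(inp, out)]


def dependence_score(inp, res):
    """Crude proxy: |corr(inp^2, res^2)|, close to 0 when independent."""
    return abs(pearson([a * a for a in inp], [r * r for r in res]))


rng = random.Random(0)
x = [rng.uniform(-1, 1) for _ in range(5000)]
y = [xi + rng.uniform(-1, 1) for xi in x]   # true model: Y = X + eps

score_xy = dependence_score(x, linear_residuals(x, y))  # causal direction
score_yx = dependence_score(y, linear_residuals(y, x))  # anti-causal direction
print(score_xy, score_yx)
```

In the causal direction the residual is essentially the noise, so the score stays near zero; in the reverse direction the residual is uncorrelated with the input but still depends on it, and the score is clearly larger.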
Exploiting the asymmetry, 2
Given A, B, learn
  $A = f(B)$
  $B = g(A)$
Retain the model with best fit: $A \rightarrow B$
A: Altitude of city, B: Temperature
Find V-structure: $A \perp C$ and $A \not\perp C \mid B$
Explaining away causes
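The explaining-away effect can be illustrated on simulated data. This sketch assumes a V-structure A → B ← C with a hypothetical linear-Gaussian mechanism, and uses partial correlation as a stand-in for a conditional independence test (adequate only in the linear-Gaussian case).

```python
import random


def pearson(u, v):
    """Pearson correlation coefficient of two equal-length sequences."""
    mu, mv = sum(u) / len(u), sum(v) / len(v)
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    var_u = sum((a - mu) ** 2 for a in u)
    var_v = sum((b - mv) ** 2 for b in v)
    return cov / (var_u * var_v) ** 0.5


def partial_corr(a, c, b):
    """Correlation of a and c after linearly controlling for b."""
    r_ac, r_ab, r_cb = pearson(a, c), pearson(a, b), pearson(c, b)
    return (r_ac - r_ab * r_cb) / ((1 - r_ab ** 2) * (1 - r_cb ** 2)) ** 0.5


rng = random.Random(0)
n = 5000
a = [rng.gauss(0, 1) for _ in range(n)]           # cause 1
c = [rng.gauss(0, 1) for _ in range(n)]           # cause 2, independent of A
b = [ai + ci + 0.5 * rng.gauss(0, 1)              # common effect (assumed form)
     for ai, ci in zip(a, c)]

r_marginal = pearson(a, c)          # A and C alone: close to 0
r_given_b = partial_corr(a, c, b)   # A and C given B: strongly negative
print(r_marginal, r_given_b)
```

Once the common effect B is known, a large value of A makes a large value of C less likely, which is why the two independent causes become dependent after conditioning.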
Auto-Encoders
Training set
  $E = \{(x_i),\ x_i \in \mathbb{R}^d,\ i = 1 \ldots n\}$
Structure of Auto-Encoder
  [Figure: auto-encoder architecture]
Minimization of Mean Squared Error (MSE)
  Minimize $\sum_i \|x_i - \hat{x}_i\|^2$
Output: $z$, a compressed representation of $x$
Stacked Auto-Encoders
  $E = \{(x_i),\ x_i \in \mathbb{R}^d,\ i = 1 \ldots n\}$
Differences
  Several hidden layers
  Minimize MSE or cross-entropy loss
    Minimize $-\sum_{i,j} \left[ x_{i,j} \log \hat{x}_{i,j} + (1 - x_{i,j}) \log (1 - \hat{x}_{i,j}) \right]$
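The two reconstruction losses can be sketched as plain functions over lists of vectors. The inputs and reconstructions below are made-up values; the clipping constant `eps` is an added safeguard, not part of the slide, and the cross-entropy assumes inputs scaled to [0, 1].

```python
import math


def squared_error(xs, xhats):
    """Sum over samples i of ||x_i - xhat_i||^2."""
    return sum(sum((a - b) ** 2 for a, b in zip(x, xh))
               for x, xh in zip(xs, xhats))


def cross_entropy(xs, xhats, eps=1e-12):
    """Binary cross-entropy summed over samples i and coordinates j."""
    total = 0.0
    for x, xh in zip(xs, xhats):
        for a, b in zip(x, xh):
            b = min(max(b, eps), 1.0 - eps)   # clip so that log stays finite
            total -= a * math.log(b) + (1.0 - a) * math.log(1.0 - b)
    return total


# Hypothetical inputs and two candidate reconstructions.
xs = [[0.0, 1.0], [1.0, 0.0]]
good = [[0.1, 0.9], [0.9, 0.1]]
bad = [[0.9, 0.1], [0.1, 0.9]]

print(squared_error(xs, good), squared_error(xs, bad))
print(cross_entropy(xs, good), cross_entropy(xs, bad))
```

Both losses rank the two reconstructions the same way here, but the cross-entropy penalizes confident mistakes much more heavily.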
Variational Auto-Encoders (Kingma et al. 13)
  $E = \{(x_i),\ x_i \in \mathbb{R}^d,\ i = 1 \ldots n\}$
Difference
  Hidden layer: parameters of a distribution $\mathcal{N}(\mu, \sigma^2)$
  Distribution used to generate values: $z = \mu + \sigma \times \mathcal{N}(0, 1)$
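The sampling step z = µ + σ × N(0, 1) can be sketched in isolation. The values of `mu` and `sigma` below are hypothetical stand-ins for what an encoder network would output; no network is trained here.

```python
import random


def sample_latent(mu, sigma, rng):
    """Coordinate-wise z = mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)
mu, sigma = [1.0, -2.0], [0.5, 0.1]     # hypothetical encoder outputs
zs = [sample_latent(mu, sigma, rng) for _ in range(10000)]

# The empirical mean of the samples should be close to mu.
mean_z = [sum(z[k] for z in zs) / len(zs) for k in range(2)]
print(mean_z)
```

Writing the sample as a deterministic function of µ, σ and an external noise draw is what lets gradients flow back to the encoder parameters during training.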
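Causal Generative Neural Nets (Goudet et al. 17)
Observed data
  $E = \{(x_i),\ x_i \in \mathbb{R}^d,\ i = 1 \ldots n\}$
Generated data
  $\hat{E} = \{(\hat{x}_i),\ \hat{x}_i \in \mathbb{R}^d,\ i = 1 \ldots n'\}$
Train the generator to minimize the distance between original and generated data in $\mathbb{R}^d$
  $MMD(G) = \frac{1}{n^2} \sum_{i,j} k(x_i, x_j) + \frac{1}{n'^2} \sum_{i,j} k(\hat{x}_i, \hat{x}_j) - \frac{2}{n n'} \sum_{i,j} k(x_i, \hat{x}_j)$
  $k(x, z) = \sum_i \exp\left(-\frac{\gamma_i}{d} \|x - z\|^2\right)$, with $\gamma_i \in \{10^{-2} \ldots 10^{2}\}$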
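A minimal sketch of the MMD criterion between an observed and a generated sample, written for small samples in pure Python. The kernel is taken to be a sum of Gaussian kernels; the five-value bandwidth grid and the division of each bandwidth by the dimension d are assumptions, and the "generated" samples here are simply drawn from Gaussians rather than produced by a trained generator.

```python
import math
import random

GAMMAS = [0.01, 0.1, 1.0, 10.0, 100.0]   # assumed grid over 10^-2 ... 10^2


def kernel(x, z):
    """Sum of Gaussian kernels; scaling gamma by 1/d is an assumption."""
    d = len(x)
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, z))
    return sum(math.exp(-(g / d) * sq_dist) for g in GAMMAS)


def mmd(real, fake):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    n, m = len(real), len(fake)
    k_rr = sum(kernel(a, b) for a in real for b in real) / n ** 2
    k_ff = sum(kernel(a, b) for a in fake for b in fake) / m ** 2
    k_rf = sum(kernel(a, b) for a in real for b in fake) / (n * m)
    return k_rr + k_ff - 2.0 * k_rf


rng = random.Random(0)


def gauss_cloud(center, n=100):
    """Hypothetical 2-D sample around (center, center)."""
    return [(rng.gauss(center, 1.0), rng.gauss(center, 1.0)) for _ in range(n)]


real = gauss_cloud(0.0)
close = gauss_cloud(0.0)   # same distribution as the observed data
far = gauss_cloud(3.0)     # shifted distribution

mmd_close = mmd(real, close)
mmd_far = mmd(real, far)
print(mmd_close, mmd_far)
```

A generator whose samples match the observed distribution gets a small score, while a shifted one is penalized; the double sums cost O(n²) kernel evaluations, so larger samples would call for mini-batches or an approximation.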
Relaxing the causal sufficiency assumption
  $X_2 = f_2(E_2, E_{2,3})$
  $X_3 = f_3(E_3, E_{2,3}, E_{3,5})$
  $X_4 = f_4(E_4, E_{4,5})$
  $X_5 = f_5(X_3, X_4, E_5, E_{3,5}, E_{4,5})$
Graph inference
Results: Area under the precision/recall curve

Algorithm               G2                G3                G4
Constraint-based
  PC-Gaussian           82.3 ±4 (87.8)    80.0 ±7 (89.2)    88.1 ±10 (95.7)
  PC-HSIC               93.4 ±3 (78.5)    93.0 ±4 (77.9)    98.9 ±2 (88.0)
Score-based
  GES                   75.3 ±7 (81.2)    73.6 ±7 (77.7)    69.3 ±11 (78.6)
Pairwise orientation
  LiNGAM                64.4 ±4 (100)     71.1 ±1 (100)     71.6 ±7 (100)
  ANM                   72.9 ±9 (100)     72.5 ±4 (100)     79.9 ±5 (100)
  Jarfo                 69.9 ±9 (100)     87.3 ±3 (100)     88.5 ±5 (100)
CGNN-Fourier            94.5 ±2 (100)     84.9 ±9 (100)     93.6 ±3 (100)
CGNN-MMD                96.9 ±1 (100)     96.5 ±3 (100)     97.2 ±3 (100)

Python framework available at: https://github.com/diviyan-kalainathan/causaldiscoverytoolbox
Caveat: up to 50 variables
Compact solar state representations
Principle
Image preprocessing
Autoencoders
  Dimensionality reduction
  Input and output similarity
  Bottleneck
  256x256 → 512
Autoencoders
  512x512 → 512
Autoencoders
  256x256 → 64
Variational Autoencoder
  Assumption on the latent space distribution
  256x256 → 90
Autoencoders training
  Intermediate image size
  Custom loss:
    $loss = \frac{(y_{true} - y_{pred})^2}{(y_{true} + \epsilon)^{\alpha}} + \frac{(y_{true} - y_{pred})^2}{(1 - y_{true} + \epsilon)^{\alpha}}$
Results

Autoencoder       Conv     Conv + Dense   Conv + PCA   Variational
Reduction rate    1/128    1/1024         1/524        1/728

Evaluation
  Visual similarity
  Smoothness over time
  Classification for verification
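The reduction rates can be related to the image and code sizes quoted on the autoencoder slides. This is only an arithmetic sketch: it assumes single-channel images, and which size pair belongs to which column of the table is not stated in the slides.

```python
def reduction_factor(height, width, code_size):
    """Number of input pixels per latent coordinate (single channel assumed)."""
    return height * width / code_size


# Size pairs quoted on the slides; a rate of 1/k corresponds to a factor k.
rate_512 = reduction_factor(256, 256, 512)
rate_64 = reduction_factor(256, 256, 64)
rate_90 = reduction_factor(256, 256, 90)
print(rate_512, rate_64, round(rate_90))
```

These three factors match the 1/128, 1/1024 and 1/728 entries of the table.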
Results

Event             precision   recall   accuracy   F1-score
Coronal hole      0.74        0.36     0.62       0.48
Lepping           0.90        0.51     0.77       0.65
Pseudo streamer   0.66        0.93     0.78       0.77
Strahl            0.55        0.98     0.73       0.70

* Random predictor performance is 0.625 for accuracy and 0.25 for the rest

Only 8000 labeled images
Time distribution
Prediction at L1
Low performance
Let's extract more information
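The F1 column and the random-predictor baseline can be checked with a few lines of arithmetic. The precision and recall values are those of the table; the baseline model (positives make up a fraction q of the data and the predictor answers "positive" with the same probability q, independently of the input) is an assumption that happens to reproduce the quoted 0.625 and 0.25 for q = 0.25.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# (precision, recall) pairs copied from the results table.
table = {
    "Coronal hole": (0.74, 0.36),
    "Lepping": (0.90, 0.51),
    "Pseudo streamer": (0.66, 0.93),
    "Strahl": (0.55, 0.98),
}
f1 = {event: round(f1_score(p, r), 2) for event, (p, r) in table.items()}
print(f1)


def random_baseline(q):
    """Assumed baseline: prevalence q, predicts positive with probability q."""
    accuracy = q * q + (1 - q) * (1 - q)   # both positive or both negative
    precision = q                          # predictions ignore the input
    recall = q
    return accuracy, precision, recall, f1_score(precision, recall)


print(random_baseline(0.25))
```

Read against this baseline, a recall of 0.36 for coronal holes is only modestly above chance, whereas the recall for Strahl (0.98) comes with a much lower precision.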
Going further
Classification of solar events
  More data
  Caveat: the train/test split
Predicting data at L1
  The propagation time from the Sun to L1
  Help needed!
Thanks
Olivier Goudet, Diviyan Kalainathan, Isabelle Guyon, Aris Tritas, Mhamed Hajaiej, Cyril Furtlehner, Aurélien Decelle