Multimodal Deep Learning for Predicting Survival from Breast Cancer
Heather Couture
Deep Learning Journal Club
Nov. 16, 2016
Outline
- Background on tumor histology & genetic data
- Background on survival analysis
- Deep survival models
  - Katzman et al., Deep Survival: A Deep Cox Proportional Hazards Network, 2016
  - Yousefi et al., Learning Genomic Representations to Predict Clinical Outcomes in Cancer, ICLR, 2016
- Multimodal deep learning
  - Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- My work: predicting survival from tumor histology & genetics
Tissue Microarray
- Tumor area identified on slide → microarray assembly
- Gene expression → genetic subtype
- Immunohistochemistry → receptor status
- H&E histology → histologic subtype, grade
(Figure: Sauter, Tissue microarrays in drug discovery, 2003)
Applications of Tumor Analysis
- Prognosis: more favorable outcome if:
  - few mitoses
  - less nuclear pleomorphism (irregularity of nuclear size and shape)
  - well differentiated (cell specialization)
- Personalized treatment: target tumors based on molecular analysis + image analysis?
(Images: gistsupport.org, imaginis.com, breastpathology.info)
Motivation
- Improve predictions by using automated image analysis:
  - faster
  - more repeatable
  - captures properties that humans cannot
  - captures spatial properties in a way that genetics cannot
- Complement genetic analysis by integrating image and genetic data into a single model
- Predict survival to identify high- and low-risk patients
Approach
- Deep survival model with multimodal data
(Figure: image features + gene expression → network → risk score)
Survival Data
- Event time T, event indicator E
  - E = 1 (e.g., death): T is the time to death
  - E = 0 (e.g., last contact with patient): T is the time of last followup (right-censored)
- Predicting survival:
  - standard regression methods must discard right-censored data
  - binary discriminative methods (e.g., death by time t) must discard the time to death
- Solution: proportional hazards model
Survival Analysis
- Survival function: S(t) = Pr(T > t)
- Hazard function: λ(t) = lim_{δ→0} Pr(t ≤ T < t + δ | T ≥ t) / δ
- Proportional hazards model: λ(t|x) = λ_0(t) e^{h(x)}
  - λ_0(t): baseline hazard function
  - h(x): risk function over covariates x
- Cox proportional hazards model: h_β(x) = β^T x
- Cox partial likelihood (maximize the partial log-likelihood):
  L_c(β) = ∏_{i: E_i = 1} e^{h_β(x_i)} / Σ_{j: T_j ≥ T_i} e^{h_β(x_j)}
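As a concrete reference, the partial likelihood above can be evaluated directly. Below is a minimal NumPy sketch for the linear Cox risk h_β(x) = β^T x; the function name and the ≥ convention for the risk set are choices made here, not taken from the slides.

```python
import numpy as np

def cox_partial_log_likelihood(beta, X, T, E):
    """Cox partial log-likelihood for the linear risk h_beta(x) = beta^T x.

    X: (n, d) covariates; T: (n,) event/censoring times; E: (n,) event
    indicators. The risk set for event i is {j : T_j >= T_i}, which keeps
    subject i in its own risk set (the usual convention)."""
    h = X @ beta                        # linear risk scores
    ll = 0.0
    for i in np.where(E == 1)[0]:       # sum only over uncensored subjects
        risk_set = T >= T[i]            # subjects still at risk at time T_i
        ll += h[i] - np.log(np.sum(np.exp(h[risk_set])))
    return ll
```

With β = 0 every risk score is equal, so each event term reduces to −log(size of its risk set), which gives a quick sanity check.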
Model Performance
- Concordance index: pairwise agreement of risk predictions
  CI(β, X) = (1/|P|) Σ_{(i,j) ∈ P} I(i, j)
  I(i, j) = 1 if h_β(x_i) > h_β(x_j) and T_j > T_i, 0 otherwise
- P: set of orderable pairs (x_i, x_j), i.e., pairs with E_i = 1 and either E_j = 1 or (E_j = 0 and T_j > T_i)
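The concordance index can be computed by brute force over pairs. A minimal sketch (the function name and the O(n²) double loop are illustrative choices):

```python
import numpy as np

def concordance_index(risk, T, E):
    """Fraction of orderable pairs whose predicted risks agree with the
    observed event order. A pair (i, j) is orderable when subject i had an
    event (E_i = 1) and subject j survived longer (T_j > T_i)."""
    concordant, orderable = 0, 0
    n = len(T)
    for i in range(n):
        if E[i] != 1:                  # censored subjects cannot anchor a pair
            continue
        for j in range(n):
            if T[j] > T[i]:            # j outlived i, so i should be higher risk
                orderable += 1
                if risk[i] > risk[j]:
                    concordant += 1
    return concordant / orderable
```

A perfect ranking gives CI = 1, a perfectly inverted ranking gives 0, and random risk scores give about 0.5.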
Deep Survival Network
- Katzman et al., Deep Survival: A Deep Cox Proportional Hazards Network, 2016
- Yousefi et al., Learning Genomic Representations to Predict Clinical Outcomes in Cancer, 2016
- Approach: replace h(x) with a deep neural network
  - input features x_i → network weights θ → risk score h_θ(x_i)
- Cost function: Cox partial log-likelihood
  L(θ) = Σ_{i: E_i = 1} ( h_θ(x_i) − log Σ_{j: T_j ≥ T_i} e^{h_θ(x_j)} )
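Swapping the linear risk for a small network changes only how h is computed; the cost stays the Cox partial log-likelihood. Below is a NumPy sketch with an illustrative one-hidden-layer network and synthetic data. The layer sizes, the finite-difference update, and all numbers are assumptions for demonstration, not the papers' training setup.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_risk(X, params):
    """One-hidden-layer network producing a scalar risk score h_theta(x)."""
    W1, b1, w2 = params
    return np.tanh(X @ W1 + b1) @ w2

def neg_cox_partial_ll(h, T, E):
    """Negative Cox partial log-likelihood of a batch of risk scores."""
    loss = 0.0
    for i in np.where(E == 1)[0]:
        at_risk = T >= T[i]                       # risk set for event i
        loss -= h[i] - np.log(np.exp(h[at_risk]).sum())
    return loss

# Synthetic toy batch: 6 subjects, 4 covariates.
X = rng.normal(size=(6, 4))
T = np.array([2.0, 5.0, 3.0, 8.0, 1.0, 6.0])
E = np.array([1, 0, 1, 1, 1, 0])

params = [rng.normal(scale=0.1, size=(4, 3)),     # W1
          np.zeros(3),                            # b1
          rng.normal(scale=0.1, size=3)]          # w2
loss = neg_cox_partial_ll(mlp_risk(X, params), T, E)

# One crude finite-difference gradient step, just to show the cost is trainable
# (a real implementation would use backpropagation instead).
eps, lr = 1e-5, 1e-2
grads = []
for p in params:
    g = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        old = p[idx]
        p[idx] = old + eps                        # perturb one weight
        g[idx] = (neg_cox_partial_ll(mlp_risk(X, params), T, E) - loss) / eps
        p[idx] = old                              # restore it
    grads.append(g)
for p, g in zip(params, grads):
    p -= lr * g
new_loss = neg_cox_partial_ll(mlp_risk(X, params), T, E)
```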
Experiments
- Katzman et al., Deep Survival: A Deep Cox Proportional Hazards Network, 2016
- Worcester Heart Attack Study: 1,638 observations; 5 features (age, sex, BMI, left heart failure complications, order of MI)
  - Linear Cox regression C-index: 0.669 (95% CI: 0.666-0.671)
  - DeepSurv C-index: 0.779 (95% CI: 0.777-0.782)
- Molecular Taxonomy of Breast Cancer: 1,981 patients; expression levels for 14 manually selected genes; clinical features: age, number of positive nodes, tumor size, receptor status, treatment
  - Linear Cox regression C-index: 0.688 (95% CI: 0.686-0.690)
  - DeepSurv C-index: 0.695 (95% CI: 0.693-0.697)
Experiments
- Yousefi et al., Learning Genomic Representations to Predict Clinical Outcomes in Cancer, ICLR, 2016
- TCGA brain tumors: 628 samples, 183 genomic features
- 10 random splits: 70% training, 30% testing
- 2 fully connected layers of 250 hidden units each
Multimodal Deep Learning
- Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- Access to multiple unlabeled views of data for representation learning, but only one view at test time
- Examples:
  - audio + video
  - images + text
  - parallel text in two languages
  - words + context
- Two approaches:
  - Canonical correlation analysis (CCA): learn features in the two views that are maximally correlated
  - Autoencoder: learn a representation that best reconstructs the inputs
Multimodal Data
- Given (x_i, y_i), i = 1, …, n
- Wish to learn f(x_i) and g(y_i) such that:
  - f(x_i) and g(y_i) are highly correlated, and/or
  - it is possible to reconstruct y_i from x_i through f(x_i), and vice versa
Canonical Correlation Analysis (CCA)
- Find projections u and v such that the data are maximally correlated:
  (u, v) = argmax_{u,v} corr(u^T X, v^T Y) = argmax_{u,v} u^T Σ_xy v / √(u^T Σ_xx u · v^T Σ_yy v)
- Equivalently, constrain the projections to have unit variance:
  maximize u^T Σ_xy v subject to u^T Σ_xx u = v^T Σ_yy v = 1
- Find multiple pairs (u_i, v_i) such that u_i^T Σ_xx u_j = v_i^T Σ_yy v_j = 0 for i < j
- With U = [u_1, …, u_k] and V = [v_1, …, v_k]:
  maximize tr(U^T Σ_xy V) subject to U^T Σ_xx U = V^T Σ_yy V = I
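The constrained trace formulation has a closed-form solution: whiten each view and take an SVD of the whitened cross-covariance. A NumPy sketch (the function name and the ridge term `reg`, added as a numerical safeguard, are choices made here):

```python
import numpy as np

def linear_cca(X, Y, k, reg=1e-6):
    """Top-k canonical directions U (dx, k), V (dy, k) and the canonical
    correlations, via whitening plus an SVD. X: (n, dx), Y: (n, dy)."""
    Xc = X - X.mean(0)
    Yc = Y - Y.mean(0)
    n = X.shape[0]
    Sxx = Xc.T @ Xc / n + reg * np.eye(X.shape[1])
    Syy = Yc.T @ Yc / n + reg * np.eye(Y.shape[1])
    Sxy = Xc.T @ Yc / n

    def inv_sqrt(S):
        # Inverse matrix square root of a symmetric positive-definite matrix.
        w, Q = np.linalg.eigh(S)
        return Q @ np.diag(w ** -0.5) @ Q.T

    Wx, Wy = inv_sqrt(Sxx), inv_sqrt(Syy)
    A, s, Bt = np.linalg.svd(Wx @ Sxy @ Wy)   # singular values = correlations
    return Wx @ A[:, :k], Wy @ Bt[:k].T, s[:k]

# Demo: Y is an exact linear function of X, so the top correlation is ~1.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
Y = X @ rng.normal(size=(3, 2))
U, V, corrs = linear_cca(X, Y, k=2)
```

Whitening turns the constraints U^T Σ_xx U = V^T Σ_yy V = I into plain orthonormality, which the SVD provides directly.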
Deep Canonical Correlation Analysis (DCCA)
- Andrew et al., Deep Canonical Correlation Analysis, ICML, 2013
- maximize: (1/N) tr(U^T f(X) g(Y)^T V)
- subject to:
  - U^T ((1/N) f(X) f(X)^T + r_x I) U = I
  - V^T ((1/N) g(Y) g(Y)^T + r_y I) V = I  (features within a modality are uncorrelated)
  - u_i^T f(X) g(Y)^T v_j = 0 for i ≠ j
- r_x, r_y: regularization parameters
Split Autoencoder (SplitAE)
- A single encoder f of view x; decoders p and q reconstruct both views from f(x_i)
- minimize: (1/N) Σ_{i=1}^N ( ‖x_i − p(f(x_i))‖² + ‖y_i − q(f(x_i))‖² )
Deep Canonically Correlated Autoencoder (DCCAE)
- Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- minimize: −(1/N) tr(U^T f(X) g(Y)^T V) + (λ/N) Σ_{i=1}^N ( ‖x_i − p(f(x_i))‖² + ‖y_i − q(g(y_i))‖² )
  (DCCA objective + autoencoder regularization)
- subject to:
  - U^T ((1/N) f(X) f(X)^T + r_x I) U = I
  - V^T ((1/N) g(Y) g(Y)^T + r_y I) V = I
  - u_i^T f(X) g(Y)^T v_j = 0 for i ≠ j
Correlated Autoencoder (CorrAE)
- Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- minimize: −(1/N) tr(U^T f(X) g(Y)^T V) + (λ/N) Σ_{i=1}^N ( ‖x_i − p(f(x_i))‖² + ‖y_i − q(g(y_i))‖² )
- subject to: u_i^T f(X) f(X)^T u_i = v_i^T g(Y) g(Y)^T v_i = N, 1 ≤ i ≤ L
- Relaxation of DCCAE: feature dimensions within each view are not constrained to be uncorrelated with each other
Experiments: Speech Recognition
- Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- Recorded speech & articulatory measurements from 47 American English speakers
- 39 acoustic & 16 articulatory features from each of 7 frames
- Roughly 50k frames/speaker; 1.43M frames total
- Apply representation learning to frames
- Use original & learned features in a standard HMM-based recognizer
- PER = phone error rate
Experiments: Multilingual Word Embeddings
- Wang et al., On Deep Multi-View Representation Learning, ICML, 2015
- Learn representations of English words from pairs of English-German word embeddings
- 640D monolingual word vectors trained via LSA
- 36K English-German word pairs
- Evaluated on 180k English word embeddings
- Bigram similarity task: add the projections of the two words in each bigram, compute cosine similarity between bigram pairs, order pairs by similarity, and measure Spearman's correlation between the model's and humans' rankings
- AN: adjective-noun; VN: verb-object
Experiments: Diagnosis of Schizophrenia
- Qi and Tejedor, Deep Multi-view Representation Learning for Multi-modal Features of Schizophrenia and Schizo-affective Disorder, ICASSP, 2016
- Features from MRI:
  - source-based morphometric loadings: 32
  - functional network connectivity: 378
- 86 labeled, 119,748 unlabeled samples
- Train SVM on learned features
My Work: Predicting Survival from Breast Cancer
- SPECS breast tumor tissue microarray
- 145 patients, 2 cores per patient
- 512 image features from VGG16
- 14,570 gene expression levels
Multimodal Deep Survival Network
- Inputs: image features x_i^(1) and genetic features x_i^(2) → network weights θ → risk score h_θ(x_i)
- L_Cox(θ) = Σ_{i: E_i = 1} ( h_θ(x_i) − log Σ_{j: T_j ≥ T_i} e^{h_θ(x_j)} )
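A minimal sketch of the two-branch idea: each modality gets its own encoder, and the fused representation feeds a final layer that outputs the risk score. All layer sizes and data below are toy stand-ins for the 512 image features and 14,570 expression levels, not the actual architecture.

```python
import numpy as np

rng = np.random.default_rng(1)

def multimodal_risk(x_img, x_gene, params):
    """Two input branches (image features, gene expression) fused by
    concatenation before a final linear layer producing the risk score."""
    W_img, W_gene, w_out = params
    h_img = np.tanh(x_img @ W_img)      # image branch
    h_gene = np.tanh(x_gene @ W_gene)   # genetic branch
    return np.concatenate([h_img, h_gene], axis=1) @ w_out

# Toy batch of 5 patients with 8 image and 12 genetic features each.
x_img = rng.normal(size=(5, 8))
x_gene = rng.normal(size=(5, 12))
params = [rng.normal(scale=0.1, size=(8, 4)),
          rng.normal(scale=0.1, size=(12, 4)),
          rng.normal(scale=0.1, size=8)]
risk = multimodal_risk(x_img, x_gene, params)
```

The fused network is trained end to end with the same Cox partial log-likelihood cost as the unimodal version.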
Multiple Outputs for Regularization
- Additional output heads alongside the risk score: subtype; grade, receptor status, etc.
- Inputs: image features x_i^(1) and genetic features x_i^(2); network weights θ
- L(θ, X, E, T, Y_subtype, Y_grade) = α L_Cox(θ, X, E, T) + L_cross-entropy(θ, X, Y_subtype) + L_cross-entropy(θ, X, Y_grade)
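The combined objective can be written as one function: a Cox term plus softmax cross-entropy for each auxiliary head, with α trading them off. A NumPy sketch, written as a minimization so the Cox term is the negative partial log-likelihood (function names and the α default are illustrative):

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax logits."""
    z = logits - logits.max(axis=1, keepdims=True)          # numerical stability
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(labels)), labels].mean()

def multitask_loss(h, T, E, subtype_logits, y_subtype,
                   grade_logits, y_grade, alpha=1.0):
    """alpha * Cox term + cross-entropy for the subtype and grade heads."""
    cox = 0.0
    for i in np.where(E == 1)[0]:
        at_risk = T >= T[i]
        cox -= h[i] - np.log(np.exp(h[at_risk]).sum())      # negative partial LL
    return (alpha * cox
            + softmax_cross_entropy(subtype_logits, y_subtype)
            + softmax_cross_entropy(grade_logits, y_grade))
```

With uniform (all-zero) logits over two classes, each cross-entropy term reduces to log 2, which gives a quick sanity check.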
Implementation Details/Tricks
- Batch normalization
- Dropout
- L2 regularization
Results
- 5-fold cross-validation ×4
Questions?