for Sum-Product Networks
Han Zhao (1), Tameem Adel (2), Geoff Gordon (1), Brandon Amos (1)
Presented by: Han Zhao
(1) Carnegie Mellon University, (2) University of Amsterdam
June 20th, 2016
Outline
- Background
  - Sum-Product Networks
  - Variational Inference
- Collapsed Variational Inference
  - Motivations and Challenges
  - Efficient Marginalization
  - Logarithmic Transformation
- Experiments
- Summary
Sum-Product Networks: Definition
A Sum-Product Network (SPN) is a rooted directed acyclic graph of univariate distributions, sum nodes, and product nodes.
- The value of a product node is the product of the values of its children.
- The value of a sum node is the weighted sum of the values of its children, where the weights are nonnegative.
- The value of the network is the value at the root.
[Figure: example SPN whose root sum node has edge weights $w_1, w_2, w_3$ and whose leaves are indicators over $x_1$ and $x_2$.]
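The definition above can be sketched as a recursive bottom-up evaluation. This is a minimal illustration under stated assumptions, not the authors' implementation: the class names (`Leaf`, `Product`, `Sum`), the function `evaluate`, and the example weights are all hypothetical, and leaves are restricted to indicator functions.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class Leaf:
    """Indicator leaf: 1 if variable `var` takes value `val`, else 0."""
    var: str
    val: int


@dataclass
class Product:
    children: List["Node"]


@dataclass
class Sum:
    children: List["Node"]
    weights: List[float]  # nonnegative edge weights


Node = Union[Leaf, Product, Sum]


def evaluate(node: Node, x: Dict[str, int]) -> float:
    """Value of `node` for a complete assignment `x` (no caching of shared nodes)."""
    if isinstance(node, Leaf):
        return 1.0 if x[node.var] == node.val else 0.0
    if isinstance(node, Product):
        value = 1.0
        for child in node.children:
            value *= evaluate(child, x)
        return value
    # Sum node: weighted sum of the children's values.
    return sum(w * evaluate(c, x) for w, c in zip(node.weights, node.children))


# Hypothetical example: a root sum node over three product nodes.
x1, nx1 = Leaf("X1", 1), Leaf("X1", 0)
x2, nx2 = Leaf("X2", 1), Leaf("X2", 0)
root = Sum(
    children=[Product([x1, x2]), Product([x1, nx2]), Product([nx1, nx2])],
    weights=[0.5, 0.3, 0.2],
)

# With normalized weights, the values over all assignments sum to one.
total = sum(evaluate(root, {"X1": a, "X2": b}) for a in (0, 1) for b in (0, 1))
```

A practical evaluator would cache the value of each node so that shared sub-graphs are visited once, which is what keeps a single pass linear in the network size.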
Sum-Product Networks: Mixture of Trees
Each SPN can be decomposed as a mixture of trees.
[Figure: an SPN with root weights $w_1, w_2, w_3$ written as the sum of three trees, weighted by $w_1$, $w_2$, and $w_3$, over indicator leaves of $x_1$ and $x_2$.]
- Each tree is a product of univariate distributions.
- The number of mixture components is $\Omega(2^{\text{Depth}})$.
- Each network computes a positive polynomial (posynomial) function of the model parameters:
$$V_{\text{root}}(x \mid w) = \sum_{t=1}^{\tau_S} \prod_{(k,j) \in T_{tE}} w_{kj} \prod_{i=1}^{n} p_t(X_i = x_i)$$
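The mixture-of-trees view can be checked on a toy network by enumerating the trees explicitly: pick one child at every sum node and keep all children at every product node. The sketch below assumes a tuple-based node encoding and helper names (`induced_trees`, `mixture_value`) that are purely illustrative. Enumeration is exponential in general and is only meant to make the formula concrete.

```python
from itertools import product as cartesian
from math import prod

# Hypothetical encoding: ("leaf", var, val), ("prod", children),
# ("sum", children, weights).


def induced_trees(node):
    """Return one (weight_product, leaves) pair per tree of the mixture."""
    kind = node[0]
    if kind == "leaf":
        return [(1.0, [node])]
    if kind == "sum":
        _, children, weights = node
        # A tree keeps exactly one child of each sum node it contains.
        return [(w * tw, leaves)
                for w, child in zip(weights, children)
                for tw, leaves in induced_trees(child)]
    # Product node: combine one subtree from every child.
    _, children = node
    trees = []
    for combo in cartesian(*(induced_trees(c) for c in children)):
        weight = prod(tw for tw, _ in combo)
        leaves = [leaf for _, ls in combo for leaf in ls]
        trees.append((weight, leaves))
    return trees


def leaf_value(leaf, x):
    _, var, val = leaf
    return 1.0 if x[var] == val else 0.0


def mixture_value(node, x):
    """Sum over trees of (product of weights) * (product of leaf values)."""
    return sum(w * prod(leaf_value(l, x) for l in leaves)
               for w, leaves in induced_trees(node))


# Toy network: a product of two sum nodes, each over two indicators.
s1 = ("sum", [("leaf", "X1", 1), ("leaf", "X1", 0)], [0.6, 0.4])
s2 = ("sum", [("leaf", "X2", 1), ("leaf", "X2", 0)], [0.7, 0.3])
root = ("prod", [s1, s2])

trees = induced_trees(root)                       # 2 * 2 = 4 trees
value = mixture_value(root, {"X1": 1, "X2": 0})   # 0.6 * 0.3
```

Each additional layer of sum nodes under a product multiplies the number of trees, which is the source of the exponential count on the slide.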
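Sum-Product Networks: Bayesian Network
Alternatively, each SPN $S$ is equivalent to a Bayesian network $B$ with bipartite structure.
[Figure: an SPN with sum nodes labeled $H, H_1, \dots, H_4$ over variables $X_1, \dots, X_4$, next to the equivalent bipartite Bayesian network with hidden variables $H, H_1, \dots, H_4$ and observable variables $X_1, \dots, X_4$.]
- Number of sum nodes in $S$ = number of hidden variables in $B$ = $\Theta(|S|)$, and $|B| = O(n|S|)$.
- Number of observable variables in $B$ = number of variables modeled by $S$.
- Typically, the number of hidden variables is much larger than the number of observable variables.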
Variational Inference: Brief Introduction
Bayesian inference:
$$\underbrace{p(w \mid x)}_{\text{posterior}} \propto \underbrace{p(w)}_{\text{prior}} \cdot \underbrace{p(x \mid w)}_{\text{likelihood}}$$
The posterior is often intractable because of:
- No analytical solution.
- Expensive numerical integration.
General idea: find the best approximation in a tractable family of distributions $Q$:
$$\operatorname*{minimize}_{q \in Q} \; \mathrm{KL}[q(w) \,\|\, p(w \mid x)]$$
Typical choices of approximation family: mean-field, structured mean-field, etc.
Variational Inference: Brief Introduction
Variational method: an optimization-based, deterministic approach to approximate Bayesian inference.
$$\inf_{q \in Q} \mathrm{KL}[q(w) \,\|\, p(w \mid x)] \iff \sup_{q \in Q} \mathbb{E}_q[\log p(w, x)] + H[q]$$
Evidence Lower Bound (ELBO) $\mathcal{L}$:
$$\log p(x) \geq \sup_{q \in Q} \mathbb{E}_q[\log p(w, x)] + H[q] =: \mathcal{L}$$
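The equivalence between minimizing the KL divergence and maximizing the ELBO follows from the identity log p(x) = ELBO(q) + KL[q || p(w | x)]. The sketch below checks it numerically on a toy model where w takes only three values. The discrete setting and all numbers are assumptions made for illustration; the slides treat continuous w.

```python
import math

# Hypothetical toy model: w in {0, 1, 2}, one fixed observation x.
prior = [0.5, 0.3, 0.2]            # p(w)
likelihood = [0.10, 0.40, 0.70]    # p(x | w)
joint = [p * l for p, l in zip(prior, likelihood)]   # p(w, x)
evidence = sum(joint)                                # p(x)
posterior = [j / evidence for j in joint]            # p(w | x)


def elbo(q):
    """E_q[log p(w, x)] + H[q] for a discrete distribution q over w."""
    expected_log_joint = sum(qi * math.log(ji) for qi, ji in zip(q, joint) if qi > 0)
    entropy = -sum(qi * math.log(qi) for qi in q if qi > 0)
    return expected_log_joint + entropy


def kl(q, p):
    """KL[q || p] for discrete distributions with p > 0."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


q = [0.2, 0.5, 0.3]                  # an arbitrary approximation
gap = math.log(evidence) - elbo(q)   # equals KL[q || posterior]
```

Because the gap is exactly the KL divergence, the bound is tight only when q equals the true posterior.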
Motivations and Challenges
Bayesian inference algorithms for SPNs are:
- Flexible in incorporating prior knowledge about the structure of SPNs.
- More robust to overfitting.
Motivations and Challenges
[Figure: plate diagram with model parameters $W_1, \dots, W_m$, hidden variables $H_1, \dots, H_m$, and observable variables $X_1, \dots, X_n$, with a plate of size $D$.]
- $W$: model parameters, global hidden variables.
- $H$: assignments of sum nodes, local hidden variables.
- $X$: observable variables.
- $D$: number of instances.
Challenges for standard VB:
- Large number of local hidden variables: number of local hidden variables = number of sum nodes = $\Theta(|S|)$.
- Memory overhead: space complexity $O(D|S|)$.
- Time complexity: $O(nD|S|)$.
Contributions
Our contributions:
- We obtain a better ELBO $\hat{\mathcal{L}}$ to optimize than $\mathcal{L}$, the one obtained by mean-field.
- Reduced space complexity: $O(D|S|) \to O(|S|)$; the space complexity is independent of the training set size.
- Reduced time complexity: $O(nD|S|) \to O(D|S|)$, removing the explicit dependency on the dimension $n$.
Efficient Marginalization
Recall the ELBO in standard VI:
$$\mathcal{L} := \mathbb{E}_{q(w,h)}[\log p(w, h, x)] + H[q(w, h)]$$
Consider the new ELBO in Collapsed VI:
$$\hat{\mathcal{L}} := \mathbb{E}_{q(w)}[\log p(w, x)] + H[q(w)] = \mathbb{E}_{q(w)}\Big[\log \sum_h p(w, h, x)\Big] + H[q(w)]$$
We can establish the following inequality:
$$\log p(x) \geq \hat{\mathcal{L}} \geq \mathcal{L}$$
The new ELBO in Collapsed VI gives a better lower bound than the one used in standard VI.
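The ordering of the three quantities can be checked numerically on a small discrete example. The sketch below assumes binary w and h and a made-up joint table; `standard_elbo` uses a mean-field q(w) q(h), while `collapsed_elbo` sums h out exactly inside the logarithm. It illustrates the inequality only and is not the SPN algorithm.

```python
import math

# joint[w][h] = p(w, h, x) for one fixed observation x (hypothetical numbers).
joint = [[0.10, 0.05],
         [0.02, 0.20]]


def entropy(q):
    return -sum(p * math.log(p) for p in q if p > 0)


def standard_elbo(q_w, q_h):
    """Mean-field ELBO with q(w, h) = q(w) q(h)."""
    expected = sum(q_w[w] * q_h[h] * math.log(joint[w][h])
                   for w in range(2) for h in range(2))
    return expected + entropy(q_w) + entropy(q_h)


def collapsed_elbo(q_w):
    """Collapsed ELBO: h is marginalized exactly, so no q(h) is needed."""
    expected = sum(q_w[w] * math.log(sum(joint[w])) for w in range(2))
    return expected + entropy(q_w)


log_evidence = math.log(sum(sum(row) for row in joint))

q_w = [0.3, 0.7]
q_h = [0.4, 0.6]
bounds = (log_evidence, collapsed_elbo(q_w), standard_elbo(q_w, q_h))
```

For every choice of `q_w` and `q_h`, the entries of `bounds` are non-increasing from left to right; the second inequality is Jensen's inequality applied to the inner sum over h.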
Comparisons
Standard Variational Inference
- Mean-field assumption: $q(w, h) = \prod_i q(w_i) \prod_j q(h_j)$
- ELBO: $\mathcal{L} := \mathbb{E}_{q(w,h)}[\log p(w, h, x)] + H[q(w, h)]$
Collapsed Variational Inference for LDA, HDP
- Collapse out the global hidden variables: $q(h) = \int_w q(w, h)\, dw$
- ELBO: $\mathcal{L}_h := \mathbb{E}_{q(h)}[\log p(h, x)] + H[q(h)]$
- Better lower bound: $\mathcal{L}_h \geq \mathcal{L}$
Collapsed Variational Inference for SPN
- Collapse out the local hidden variables: $q(w) = \sum_h q(w, h)$
- ELBO: $\mathcal{L}_w := \mathbb{E}_{q(w)}[\log p(w, x)] + H[q(w)]$
- Better lower bound: $\mathcal{L}_w \geq \mathcal{L}$
Efficient Marginalization
Time complexity of the exact marginalization incurred in computing $\sum_h p(w, h, x)$:
- Marginalization in a general graphical model $G$: $O(D \cdot 2^{\mathrm{tw}(G)})$, where $\mathrm{tw}(G)$ is the treewidth of $G$.
- Exact marginalization in the BN $B$ with algebraic decision diagrams as local factors: $O(D|B|) = O(nD|S|)$.
- Exact marginalization in the SPN $S$: $O(D|S|)$.
Space complexity reduction:
- No posterior over $h$ to approximate anymore.
- No variational variables over $h$ needed: $O(D|S|) \to O(|S|)$.
Logarithmic Transformation
New optimization objective:
$$\operatorname*{maximize}_{q \in Q} \; \mathbb{E}_{q(w)}\Big[\log \sum_h p(w, h, x)\Big] + H[q(w)]$$
which is equivalent to
$$\operatorname*{minimize}_{q \in Q} \; \mathrm{KL}[q(w) \,\|\, p(w)] - \mathbb{E}_{q(w)}[\log p(x \mid w)]$$
- $p(w)$: prior distribution over $w$, a product of Dirichlets.
- $q(w)$: variational posterior over $w$, a product of Dirichlets.
- $p(x \mid w)$: likelihood, no longer multinomial after marginalization.
$q(w)$ and $p(x \mid w)$ are non-conjugate, so there is no analytical solution for $\mathbb{E}_{q(w)}[\log p(x \mid w)]$.
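The equivalence of the two objectives follows in a few steps from $\sum_h p(w, h, x) = p(w, x) = p(x \mid w)\, p(w)$ and $H[q] = -\mathbb{E}_q[\log q]$. The symbol $J(q)$ for the maximization objective is introduced here only for this derivation.

```latex
\begin{aligned}
J(q) &= \mathbb{E}_{q(w)}\Big[\log \sum_h p(w, h, x)\Big] + H[q(w)] \\
     &= \mathbb{E}_{q(w)}[\log p(x \mid w)] + \mathbb{E}_{q(w)}[\log p(w)]
        - \mathbb{E}_{q(w)}[\log q(w)] \\
     &= \mathbb{E}_{q(w)}[\log p(x \mid w)] - \mathrm{KL}[q(w) \,\|\, p(w)]
\end{aligned}
```

Maximizing $J(q)$ is therefore the same as minimizing $\mathrm{KL}[q(w) \,\|\, p(w)] - \mathbb{E}_{q(w)}[\log p(x \mid w)]$.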
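Logarithmic Transformation
Key observation:
$$p(x \mid w) = V_{\text{root}}(x \mid w) = \sum_{t=1}^{\tau_S} \prod_{(k,j) \in T_{tE}} w_{kj} \prod_{i=1}^{n} p_t(X_i = x_i)$$
is a posynomial function of $w$.
- Apply a bijective mapping (change of variables): $w' = \log(w)$. This dates back to the geometric programming literature.
- After the transformation, $\log p(x \mid w')$ is a log-sum-exp of affine functions of $w'$, and hence convex in $w'$:
$$\log p(x \mid w') = \log \sum_{t=1}^{\tau_S} \exp\Big(c_t + \sum_{(k,j) \in T_{tE}} w'_{kj}\Big), \qquad c_t := \log \prod_{i=1}^{n} p_t(X_i = x_i)$$
- Jensen's inequality then gives a further lower bound.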
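The transformed log-likelihood is a log-sum-exp over trees, which can be evaluated stably and checked for convexity on a toy case. In the sketch below the list `trees`, the constants `c_t`, and the edge indices are invented for illustration. A real SPN would never enumerate its trees, since their number can be exponential.

```python
import math


def log_likelihood(w_log, trees):
    """log p(x | w') = log sum_t exp(c_t + sum of w'_kj over the edges of tree t).

    `trees` is a list of (c_t, edge_indices) pairs and `w_log` holds w' = log w.
    """
    exponents = [c + sum(w_log[e] for e in edges) for c, edges in trees]
    m = max(exponents)  # subtract the maximum for numerical stability
    return m + math.log(sum(math.exp(a - m) for a in exponents))


# Hypothetical toy network: 3 trees sharing 4 edge weights.
trees = [(-1.2, [0, 2]), (-0.7, [0, 3]), (-2.0, [1])]

a = [math.log(v) for v in (0.6, 0.4, 0.7, 0.3)]
b = [math.log(v) for v in (0.1, 0.9, 0.5, 0.5)]
mid = [(ai + bi) / 2 for ai, bi in zip(a, b)]

# Midpoint convexity in w': f((a + b) / 2) <= (f(a) + f(b)) / 2.
lhs = log_likelihood(mid, trees)
rhs = 0.5 * (log_likelihood(a, trees) + log_likelihood(b, trees))
```

Here `lhs <= rhs` holds for any pair of points `a` and `b`, which is the convexity property that Jensen's inequality relies on.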
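Logarithmic Transformation
Further lower bound:
$$\mathbb{E}_{q(w)}[\log p(x \mid w)] = \mathbb{E}_{q'(w')}[\log p(x \mid w')] \geq \log p(x \mid \mathbb{E}_{q'(w')}[w'])$$
Relaxed objective:
$$\operatorname*{minimize}_{q \in Q} \; \underbrace{\mathrm{KL}[q(w) \,\|\, p(w)]}_{\text{Regularity}} - \underbrace{\log p(x \mid \mathbb{E}_{q'(w')}[w'])}_{\text{Data fitting}}$$
- Roughly, $\log p(x \mid \mathbb{E}_{q'(w')}[w'])$ corresponds to the log-likelihood obtained by setting the weights of the SPN to the posterior mean of $q(w)$.
- Optimized by projected gradient descent (GD).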
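For a Dirichlet $q(w)$ the expectation of $w' = \log w$ has a closed form, $\mathbb{E}_q[\log w_k] = \psi(\alpha_k) - \psi(\sum_j \alpha_j)$, which is what makes the relaxed data-fitting term computable. The sketch below compares the Jensen bound with a Monte Carlo estimate of the original expectation. It assumes a single sum node with three children, and the values in `alpha` and `leaf_lik` are made up. The small `digamma` helper is a standard recurrence-plus-asymptotic-series approximation, written out only to keep the example dependency-free.

```python
import math
import random


def digamma(x):
    """Digamma for x > 0: shift x upward, then use the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))
    return result + math.log(x) - 0.5 / x - series


alpha = [2.0, 3.0, 5.0]        # hypothetical Dirichlet parameters of q(w)
leaf_lik = [0.9, 0.2, 0.05]    # hypothetical values of the children at x


def log_lik(w_log):
    """log p(x | w') for one sum node: log sum_k exp(log l_k + w'_k)."""
    return math.log(sum(l * math.exp(v) for l, v in zip(leaf_lik, w_log)))


# Closed form: E_q[log w_k] = digamma(alpha_k) - digamma(sum(alpha)).
alpha0 = sum(alpha)
mean_log_w = [digamma(a) - digamma(alpha0) for a in alpha]
jensen_bound = log_lik(mean_log_w)

# Monte Carlo estimate of E_q[log p(x | w)] for comparison.
rng = random.Random(0)


def sample_dirichlet():
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [v / total for v in g]


samples = [sample_dirichlet() for _ in range(20000)]
mc_estimate = sum(log_lik([math.log(w) for w in s]) for s in samples) / len(samples)
```

The Monte Carlo value sits above `jensen_bound`, as the inequality requires, and the bound itself needs no sampling.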
Algorithm
[Algorithm listing]
- Lines 4 to 8 are easily parallelizable, which gives a distributed version.
- Sampling a minibatch in Lines 4 to 8 gives a stochastic version.
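The generic shape of a projected gradient loop, with per-instance gradients that can be computed in parallel or on a sampled minibatch, might look like the sketch below. It is not the algorithm from the slide: the function names, the step size, the lower bound `eps`, and the toy quadratic objective are all assumptions standing in for the CVB-SPN objective and its gradient.

```python
import random


def project(alpha, eps=1e-2):
    """Project onto the feasible set {alpha : alpha_k >= eps}."""
    return [max(a, eps) for a in alpha]


def projected_gd(grad_fn, data, alpha, lr=0.1, steps=200, batch_size=None, seed=0):
    """Gradient step on the averaged per-instance gradient, then projection."""
    rng = random.Random(seed)
    for _ in range(steps):
        # Full batch by default; a sampled minibatch gives a stochastic variant.
        batch = data if batch_size is None else rng.sample(data, batch_size)
        # Per-instance gradients are independent, so this loop could run in parallel.
        grads = [grad_fn(alpha, d) for d in batch]
        avg = [sum(g[k] for g in grads) / len(batch) for k in range(len(alpha))]
        alpha = project([a - lr * g for a, g in zip(alpha, avg)])
    return alpha


def toy_grad(alpha, d):
    """Gradient of the stand-in objective 0.5 * ||alpha - d||^2 (not the SPN objective)."""
    return [a - t for a, t in zip(alpha, d)]


data = [[1.0, -2.0], [3.0, -1.0], [2.0, -3.0]]
result = projected_gd(toy_grad, data, alpha=[5.0, 5.0])
```

In this toy problem the unconstrained minimizer is the data mean `[2.0, -2.0]`, so the second coordinate ends up clamped at the lower bound `eps` while the first converges to `2.0`.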
Experiments
- Experiments on 20 data sets, reporting average log-likelihoods and a Wilcoxon signed-rank test.
- Compared with (O)MLE-SPN and OBMM-SPN.
[Figure, two panels of average log-likelihoods. Panel 1: "MLE-SPN vs CVB-SPN", axes MLE-Projected GD and CVB-Projected GD. Panel 2: "OMLE-SPN vs OCVB-SPN", axes OMLE-Projected GD and OCVB-Projected GD.]
Experiments
[Figure: "OBMM vs OCVB, Avg. log-likelihoods", axes OBMM-Moment Matching and OCVB-Projected.]
Summary
Thanks
Q & A