AMP-inspired Deep Networks, with Comms Applications Phil Schniter Collaborators: Sundeep Rangan (NYU), Alyson Fletcher (UCLA), Mark Borgerding (OSU) Supported in part by NSF grants IIP-1539960, CCF-1527162, and CCF-1716388. Intl. Conf. on Signal Processing and Communications July 2018
Deep Neural Networks Typical feedforward setup: Many layers, consisting of (affine) linear stages and scalar nonlinearities. rag replacements LIN LIN... LIN non-linear Linear stages often constrained (e.g., small convolution kernels). Parameters learned by minimizing training error using backpropagation. Open questions: 1 How should we interpret the learned parameters? 2 Can we speed up training? 3 Can we design a better network structure? Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 2 / 19
Focus of this talk: Standard Linear Regression Consider recovering a vector x from noisy linear observations y = Ax+w, where x is drawn from an iid prior (e.g., sparse 1 ) entsfor this application, we propose a deep network that is 1) asymptotically optimal for a large class of A, 2) interpretable, and 3) easy to train. pling pling LIN LIN... linear non-linear LIN 1 Gregor/LeCun, Sprechmann/Bronstein/Sapiro, Kamilov/Mansour, Wang/Ling/Huang, Mousavi/Baraniuk, Xin/Wang/Gao/Wipf, Borgerding/Schniter/Rangan, etc. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 3 / 19
Understanding Deep Networks via Algorithms Many algorithms have been proposed for linear regression. Often, such algorithms are iterative, where each iteration consists of a linear operation followed by scalar nonlinearities. By unfolding such algorithms, we get deep networks. 2 rag replacements LIN LIN... LIN non-linear Can such algorithms help us design/interpret/train deep nets? 2 Gregor/LeCun, ICML 10. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 4 / 19
Algorithmic Approaches to Standard Linear Regression Recall goal: recovering/fitting x from noisy linear observations y = Ax+w. A popular approach is regularized loss minimization: arg min x where, e.g., f(x) = x 1 for the lasso. 1 2 y Ax 2 +σ 2 f(x), Can also be interpreted as MAP estimation of x under priors x exp( f(x)) & w N(0,σ 2 I). But often the goal is minimizing MSE or inferring marginal posteriors. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 5 / 19
High-dimensional MMSE Inference High dimensional MMSE inference is difficult in general. To simplify things, suppose that 1) x is iid 2) A is large and random. The case of iid Gaussian A is well studied, but very restrictive. Instead, consider right-rotationally invariant (RRI) A: A = USV T with V Haar and indep of x. For this case, the MMSE is 34 E(γ) = var{x r}, r = x+n(0,1/γ), γ = R A T A/σ ( E(γ)) 2 3 Tulino/Caire/Verdu/Shamai, IEEE-TIT 13, 4 Reeves, Allerton 17. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 6 / 19
Achieving MMSE in standard linear regression Recently a vector approximate message passing (VAMP) algorithm has been proposed that iterates linear vector estimation with nonlinear scalar denoising. (Closely related to expectation propagation. 5 ) Under large RRI A and Lipschitz denoisers, VAMP is rigorously characterized by a scalar state-evolution. 6 When the state-evolution has a unique fixed point, the VAMP solutions are MMSE! 5 Opper/Winther, JMLR 05, 6 Rangan/Schniter/Fletcher, arxiv:1610.03082. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 7 / 19
VAMP for linear regression initialize r 1,γ 1 for t = 0,1,2,... x 1 ( A T A/σ 2 +γ 1 I ) 1( A T y/σ 2 ) +γ 1 r 1 LMMSE [ α 1 γ 1 (A N Tr T A/σ 2 +γ 1 I ) ] 1 divergence r 2 1 ) 1 α ( x1 1 α 1 r 1 Onsager correction 1 α γ 2 γ 1 1 α 1 precision of r 2 x 2 g(r 2 ;γ 2 ) (scalar) denoising [ ] α 2 1 N Tr g r (r 2;γ 2 ) divergence r 1 1 ) 1 α ( x2 2 α 2 r 2 Onsager correction 1 α γ 1 γ 2 2 α 2 precision of r 1 end Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 8 / 19
MMSE-VAMP interpreted initialize r 1,γ 1 for t = 0,1,2,... end x 1 MMSE estimate of x under pseudo-prior N(r 1,I/γ 1 ) & measurement y = Ax+N(0,σ 2 I) r 2 linear cancellation of r 1 from x 1 x 2 MMSE estimate of x under prior p(x) and pseudo-measurement r 2 = x+n(0,i/γ 2 ) r 1 linear cancellation of r 2 from x 2 Linear cancellation decouples the iterations, so that global MMSE problem can be tackled by solving simpler local MMSE problems. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 9 / 19
Unfolding VAMP ents Unfolding the VAMP algorithm gives the network 7 ling ling near LIN LIN... LIN non-linear Notice the two stages in each layer. 7 Borgerding/Schniter, IEEE-TSP 17 Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 10 / 19
Learning the network parameters After unfolding an algorithm, one can use backpropagation (or similar) to learn the optimal network parameters. 8 Linear stage: x 1 = Br 1 +Cy learn (B,C) for each layer. Nonlinear stage: x 2,j = g(r 2,j ) j learn a scalar function g( ) for each layer. e.g., spline, piecewise linear, etc. 8 Gregor/LeCun, ICML 10. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 11 / 19
Result of learning Suppose that the training data {y (d),x (d) } D d=1 1) iid x (d) j p 2) y (d) = Ax (d) +N(0,σ 2 I) 3) right-rotationally invariant A. were constructed using Backpropagation recovers the parameter settings (B, C, g) originally prescribed by the VAMP algorithm! Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 12 / 19
Result of learning Suppose that the training data {y (d),x (d) } D d=1 1) iid x (d) j p 2) y (d) = Ax (d) +N(0,σ 2 I) 3) right-rotationally invariant A. were constructed using Backpropagation recovers the parameter settings (B, C, g) originally prescribed by the VAMP algorithm! The learned linear stages are MMSE under pseudo-prior x N(r 1,I/γ 1 ) & measurement y = Ax+N(0,σ 2 I) The learned scalar nonlinearities are MMSE under prior x j p and pseudo-measurements r 2,j = x j +N(0,1/γ 2 ) This deep network is interpretable! Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 12 / 19
Implications for training Due to the stages... each linear stage is locally MSE optimal, so the linear params (B,C) can be learned locally in each layer! each non-linear stage is locally MSE optimal, so the nonlinear function g( ) can be learned locally in each layer! This deep network is easy to train! Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 13 / 19
Example with iid Gaussian A iid Gaussian ts NMSE (db) -10-15 -20-25 -30-35 -40-45 LISTA LVAMP-pwlin matched VAMP MMSE 2 4 6 8 10 12 14 Layers n = 1024 m/n = 0.5 A iid N(0,1) x Bernoulli-Gaussian Pr{x 0} = 0.1 SNR = 40 db Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 14 / 19
Example with non-iid Gaussian A NMSE (db) -10-15 -20-25 -30-35 -40 condition number = 15 LISTA LVAMP-pwlin matched VAMP MMSE n = 1024 m/n = 0.5 A = USV T U,V Haar s n /s n 1 = φ n x Bernoulli-Gaussian Pr{x 0} = 0.1 ts -45 2 4 6 8 10 12 14 Layers SNR = 40 db Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 15 / 19
Deep nets vs algorithms NMSE [db] 0-10 -20-30 -40 ISTA FISTA AMP-L1 LISTA AMP-BG LVAMP-pwlin MMSE n = 1024 m/n = 0.5 A iid N(0,1) x Bernoulli-Gaussian Pr{x 0} = 0.1 ts -50 10 0 10 1 10 2 10 3 10 4 layers / iterations SNR = 40 db Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 16 / 19
Application: Compressive random access Devices periodically wake up and broadcast short data bursts. The BS must simultaneous detect which devices are active and estimate their multipath gains. NMSE (db) We use deep networks for this. PSfrag replacements -4-6 -8-10 -12-14 -16 LISTA LISTAuntied LAMPbg LVAMPbg LAMPpwgrid LAMPpwgriduntied LVAMP-pwgrid Internet of Things 1 2 3 4 5 6 Layers Experiment: 512 users, 1% active single antenna BS Pilots: i.i.d. QPSK, length 64 flat Ricean fading SNR: 10dB Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 17 / 19
Application: Massive-MIMO channel estimation Massive-array BS communicates with single-antenna mobiles. Pilot reuse contaminates channel estimates. Random pilots can alleviate this problem. in-cell NMSE (db) We use deep networks to simultaneously estimate channels PSfrag replacements of in-cell and out-of-cell users. -7-8 -9-10 -11-12 -13 LISTA LISTAuntied LAMPbg LVAMPbg LAMPpwgrid LAMPpwgriduntied LVAMP-pwgrid Base Station -14 1 2 3 4 5 6 Layers Experiment: BS: 64-element ULA 64 in-cell users, 448 total users Pilots: i.i.d. QPSK, length 64 flat Ricean fading SNR: 20dB Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 18 / 19
Conclusions Our goal is to understand the design and interpretation of deep nets. For this talk, we focused on the task of sparse linear regression. We proposed a deep net that is 1) asymptotically MSE-optimal (for iid x and RRI A) 2) interpretable:... LMMSE//NL-MMSE/... 3) locally trainable. The proposed network is obtained by unfolding the VAMP algorithm and learning its parameters. We demonstrated the approach in wireless comms applications. Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 19 / 19
Thanks! TensorFlow code at https://github.com/mborgerding/onsager_deep_learning Phil Schniter (Ohio State) AMP-inspired Deep Networks SPCOM 18 19 / 19