Mini-Course 1: SGD Escapes Saddle Points


Yang Yuan, Computer Science Department, Cornell University

Gradient Descent (GD)

Task: $\min_x f(x)$. GD does iterative updates
$$x_{t+1} = x_t - \eta_t \nabla f(x_t)$$

GD has at least two problems:
- Computing the full gradient is slow for big data.
- It can get stuck at stationary points.
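To make the second problem concrete, here is a minimal sketch (plain NumPy; the quadratic test function and step size are my own illustrative choices, not from the lecture) of GD converging to a saddle point when initialized on its attracting direction:

```python
import numpy as np

def grad_f(x):
    # f(x) = x[0]**2 - x[1]**2 has a saddle at the origin:
    # curvature +2 along x[0], -2 along x[1].
    return np.array([2 * x[0], -2 * x[1]])

x = np.array([1.0, 0.0])      # start on the saddle's attracting direction
eta = 0.1
for _ in range(100):
    x = x - eta * grad_f(x)   # x_{t+1} = x_t - eta * grad f(x_t)

print(x)  # ~[0, 0]: GD converged to the saddle, a stationary point but not a minimum
```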

Stochastic Gradient Descent (SGD)

Very similar to GD, but the gradient now has some randomness:
$$x_{t+1} = x_t - \eta_t g_t, \quad \text{where } \mathbb{E}[g_t] = \nabla f(x_t).$$

Why do we use SGD? Initially because:
- Much cheaper to compute using mini-batches
- Can still converge to the global minimum in the convex case

But now people realize:
- Can escape saddle points! (Today's topic)
- Can escape shallow local minima (Next time's topic; some progress.)
- Can find local minima that generalize well (Not well understood)

Therefore, it's not only faster, but it also works better!

About the $g_t$ that we use
$$x_{t+1} = x_t - \eta_t g_t, \quad \text{where } \mathbb{E}[g_t] = \nabla f(x_t).$$
- In practice, $g_t$ is obtained by sampling a minibatch of size 128 or 256 from the dataset.
- To simplify the analysis, we assume $g_t = \nabla f(x_t) + \xi_t$, where $\xi_t \sim N(0, I)$ or $\xi_t$ is uniform over a ball $B_0(r)$.
- In general, the analysis works whenever $\xi_t$ has a non-negligible component in every direction.
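A minimal sketch of this idealized noise model (the test function, step size, and iteration count are illustrative assumptions): the same saddle that traps plain GD is escaped once isotropic Gaussian noise is added.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_f(x):
    # Same saddle as before: f(x) = x[0]**2 - x[1]**2.
    return np.array([2 * x[0], -2 * x[1]])

x = np.array([1.0, 0.0])      # again starts on the saddle's attracting direction
eta = 0.01
for _ in range(1000):
    g = grad_f(x) + rng.standard_normal(2)  # g_t = grad f(x_t) + xi_t, xi_t ~ N(0, I)
    x = x - eta * g

print(x)  # noise pushed x off the axis; x[1] grows (this f is unbounded below)
```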

Preliminaries
- $L$-Lipschitz: $|f(w_1) - f(w_2)| \le L \|w_1 - w_2\|_2$
- $l$-smoothness: the gradient is $l$-Lipschitz, i.e. $\|\nabla f(w_1) - \nabla f(w_2)\|_2 \le l \|w_1 - w_2\|_2$
- $\rho$-Hessian smoothness: the Hessian matrix is $\rho$-Lipschitz, i.e. $\|\nabla^2 f(w_1) - \nabla^2 f(w_2)\|_{sp} \le \rho \|w_1 - w_2\|_2$

We need Hessian smoothness because we will use the Hessian at the current point to approximate $f$ in a neighborhood, and then bound the approximation error.

Saddle points and negative eigenvalues

Stationary points: saddle points, local minima, local maxima

For stationary points, $\nabla f(w) = 0$:
- If $\nabla^2 f(w) \succ 0$, it's a local minimum.
- If $\nabla^2 f(w) \prec 0$, it's a local maximum.
- If $\nabla^2 f(w)$ has both positive and negative eigenvalues, it's a saddle point.
- Degenerate case: $\nabla^2 f(w)$ has eigenvalues equal to 0. It could be either a local minimum (maximum) or a saddle point; $f$ is flat along some directions, and SGD behaves like a random walk there.

We only consider the non-degenerate case!
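As a sketch, this classification is just an eigenvalue check on the Hessian (the tolerance `tol` for calling an eigenvalue "zero" is an assumption of mine):

```python
import numpy as np

def classify_stationary_point(hessian, tol=1e-8):
    """Classify a stationary point (where grad f = 0) from its Hessian."""
    eig = np.linalg.eigvalsh(hessian)   # sorted eigenvalues of a symmetric matrix
    if np.any(np.abs(eig) <= tol):
        return "degenerate (flat directions; could be anything)"
    if np.all(eig > 0):
        return "local minimum"
    if np.all(eig < 0):
        return "local maximum"
    return "saddle point"

print(classify_stationary_point(np.diag([2.0, -2.0])))  # saddle point
```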

Strict saddle property

$f(w)$ is $(\alpha, \gamma, \epsilon, \zeta)$-strict saddle if, for any $w$, at least one of the following holds:
- $\|\nabla f(w)\|_2 \ge \epsilon$: the gradient is large; or
- $\lambda_{\min}(\nabla^2 f(w)) \le -\gamma < 0$: at a stationary point, we have a negative-eigenvalue direction to escape along; or
- there exists a local minimum $w^\star$ such that $\|w - w^\star\|_2 \le \zeta$, and the region centered at $w^\star$ with radius $2\zeta$ is $\alpha$-strongly convex: we are pretty close to a local minimum.

Strict saddle functions are everywhere
- Orthogonal tensor decomposition [Ge et al 2015]
- Deep linear (residual) networks [Kawaguchi 2016], [Hardt and Ma 2016]
- Matrix completion [Ge et al 2016]
- Generalized phase retrieval problem [Sun et al 2016]
- Low-rank matrix recovery [Bhojanapalli et al 2016]

Moreover, in these problems all local minima are equally good! That means: SGD escapes all saddle points, so when SGD arrives at a local minimum, it has found a global minimum. This is one popular way to prove that SGD solves the problem.

Main Results
- [Ge et al 2015]: with high probability (whp), SGD escapes all saddle points and converges to a local minimum. The convergence time depends polynomially on the dimension $d$.
- [Jin et al 2017]: whp, PGD (perturbed gradient descent, a variant of SGD) escapes all saddle points and converges to a local minimum much faster: the dependence on $d$ is only logarithmic.

Both results use the same proof framework. We'll mainly look at the newer one.

Description of PGD

Do the following iteratively:
- If $\|\nabla f(x_t)\| \le g_{thres}$ and the last perturbation was more than $t_{thres}$ steps before, do a random perturbation (uniform over a ball).
- If the perturbation happened $t_{thres}$ steps ago but $f$ has decreased by less than $f_{thres}$, return the value from before the last perturbation.
- Do a gradient descent step: $x_{t+1} = x_t - \eta \nabla f(x_t)$.

A few remarks (a runnable sketch follows below):
- Unfortunately, it is not a fast algorithm, because of the full-gradient GD steps!
- $\eta = c/l$; $g_{thres}$, $t_{thres}$, $f_{thres}$ depend on a constant $c$, as well as on the other parameters.
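Here is a sketch of PGD following the three rules above. The function name `pgd` and all threshold values a caller would pass are illustrative; the paper's specific formulas tying $g_{thres}$, $t_{thres}$, $f_{thres}$ to $c$ are not reproduced here.

```python
import numpy as np

def pgd(f, grad_f, x0, eta, g_thres, t_thres, f_thres, r, max_iter=10_000, seed=0):
    """Perturbed gradient descent, following the three rules on the slide."""
    rng = np.random.default_rng(seed)
    x = x0.astype(float).copy()
    t_perturb = -np.inf          # time of the last perturbation
    x_before, f_before = None, None
    for t in range(max_iter):
        # Rule 1: gradient small and no recent perturbation -> perturb in a ball.
        if np.linalg.norm(grad_f(x)) <= g_thres and t - t_perturb > t_thres:
            x_before, f_before, t_perturb = x.copy(), f(x), t
            xi = rng.standard_normal(x.shape)
            # uniform point in a radius-r ball: random direction, radius r * U^(1/d)
            x = x + r * rng.uniform() ** (1 / x.size) * xi / np.linalg.norm(xi)
        # Rule 2: perturbed t_thres steps ago but f barely decreased -> stop.
        if t - t_perturb == t_thres and f_before - f(x) < f_thres:
            return x_before      # the value from before the last perturbation
        # Rule 3: a plain gradient descent step.
        x = x - eta * grad_f(x)
    return x
```

On a strict saddle function, the returned point is, whp, near a local minimum; the loop is slow precisely because every step needs the full gradient.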

Main theorem in [Jin et al 2017]

Theorem (Main Theorem). Assume $f$ is $l$-smooth, $\rho$-Hessian Lipschitz, and $(\alpha, \gamma, \epsilon, \zeta)$-strict saddle. There exists an absolute constant $c_{\max}$ such that, for any $\delta > 0$, $\Delta_f \ge f(x_0) - f^\star$, constant $c \le c_{\max}$, and $\epsilon' = \min\{\epsilon, \gamma^2/\rho\}$, PGD($c$) will output a point $\zeta$-close to a local minimum, with probability $1 - \delta$, and terminate in the following number of iterations:
$$O\!\left(\frac{l(f(x_0) - f^\star)}{\epsilon'^2} \log^4\!\left(\frac{d\, l\, \Delta_f}{\epsilon'^2 \delta}\right)\right)$$

- If one could show that SGD has a similar property, that would be great!
- The convergence rate is almost optimal.

More general version: why it's fast

Theorem (A more general version). Assume $f$ is $l$-smooth and $\rho$-Hessian Lipschitz. There exists an absolute constant $c_{\max}$ such that, for any $\delta > 0$, $\Delta_f \ge f(x_0) - f^\star$, constant $c \le c_{\max}$, and $\epsilon \le l^2/\rho$, PGD($c$) will output an $\epsilon$-second-order stationary point, with probability $1 - \delta$, and terminate in the following number of iterations:
$$O\!\left(\frac{l(f(x_0) - f^\star)}{\epsilon^2} \log^4\!\left(\frac{d\, l\, \Delta_f}{\epsilon^2 \delta}\right)\right)$$

Essentially saying the same thing: if $f$ is not strict saddle, only an $\epsilon$-second-order stationary point (instead of a local minimum) is guaranteed.

$\epsilon$-stationary points
- $\epsilon$-first-order stationary point: $\|\nabla f(x)\| \le \epsilon$
- $\epsilon$-second-order stationary point: $\|\nabla f(x)\| \le \epsilon$ and $\lambda_{\min}(\nabla^2 f(x)) \ge -\sqrt{\rho\epsilon}$
- If $f$ is $l$-smooth, then $|\lambda_{\min}(\nabla^2 f(x))| \le l$.
- Hence for any $\epsilon > l^2/\rho$, an $\epsilon$-first-order stationary point of an $l$-smooth function is automatically an $l^2/\rho$-second-order stationary point.
- If $f$ is $(\alpha, \gamma, \epsilon, \zeta)$-strict saddle and $\epsilon < \gamma^2/\rho$, then any $\epsilon$-second-order stationary point is a local minimum.
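A sketch of the second-order condition as a numerical check (the caller supplies the gradient, the Hessian, and the Hessian-Lipschitz constant $\rho$):

```python
import numpy as np

def is_eps_second_order_stationary(grad, hess, eps, rho):
    """Check ||grad f(x)|| <= eps and lambda_min(hess f(x)) >= -sqrt(rho * eps)."""
    small_grad = np.linalg.norm(grad) <= eps
    lam_min = np.linalg.eigvalsh(hess)[0]   # smallest eigenvalue
    return small_grad and lam_min >= -np.sqrt(rho * eps)

# A saddle with lambda_min = -2 fails the check even though its gradient is zero:
print(is_eps_second_order_stationary(np.zeros(2), np.diag([2.0, -2.0]),
                                     eps=0.01, rho=1.0))  # False
```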

[Nesterov, 1998]

Theorem. Assume that $f$ is $l$-smooth. Then for any $\epsilon > 0$, if we run GD with step size $\eta = 1/l$ and termination condition $\|\nabla f(x)\| \le \epsilon$, the output will be an $\epsilon$-first-order stationary point, and the algorithm terminates in the following number of iterations:
$$\frac{l(f(x_0) - f^\star)}{\epsilon^2}$$

[Jin et al 2017]: PGD converges to an $\epsilon$-second-order stationary point in $O\!\left(\frac{l(f(x_0) - f^\star)}{\epsilon^2} \log^4\!\left(\frac{d\, l\, \Delta_f}{\epsilon^2 \delta}\right)\right)$ steps. Matched up to log factors!

Why $\sqrt{\rho\epsilon}$?

If we use the cubic-regularized third-order approximation at $x$ [Nesterov and Polyak, 2006]:
$$\min_y \left\{ \langle \nabla f(x), y - x\rangle + \frac{1}{2}\langle \nabla^2 f(x)(y - x), y - x\rangle + \frac{\rho}{6}\|y - x\|^3 \right\}$$
denote the minimizer as $T_x$, and denote the distance $r = \|x - T_x\|$. Then
$$\|\nabla f(T_x)\| \le \rho r^2, \qquad \nabla^2 f(T_x) \succeq -\frac{3}{2}\rho r I.$$
To get a lower bound for $r$:
$$r \ge \max\left\{ \sqrt{\frac{\|\nabla f(T_x)\|}{\rho}},\ -\frac{2}{3\rho}\,\lambda_{\min}(\nabla^2 f(T_x)) \right\}$$
When are the two terms equal? Exactly when $\lambda_{\min}(\nabla^2 f(T_x)) = -\frac{3}{2}\sqrt{\rho \|\nabla f(T_x)\|}$; so with gradient threshold $\epsilon$, the matching curvature threshold is of order $\sqrt{\rho\epsilon}$.

Related results
1. "Gradient Descent Converges to Minimizers", by Lee, Simchowitz, Jordan and Recht, '15: with random initialization, GD almost surely never touches any saddle point, and always converges to a local minimum.
2. "The power of normalization: faster evasion of saddle points", Kfir Levy, '16: normalized gradient descent can escape saddle points in $O(d^3\,\mathrm{poly}(1/\epsilon))$ steps, slower than [Jin et al 2017], faster than [Ge et al 2015], but still polynomial in $d$.

Recall the main theorem in [Jin et al 2017] stated above; we now turn to its proof.

Proof framework: Progress, Escape, and Trap
- Progress: when $\|\nabla f(x)\| > g_{thres}$, each step decreases $f(x)$ by at least $f_{thres}/t_{thres}$.
- Escape: when $\|\nabla f(x)\| \le g_{thres}$ and $\lambda_{\min}(\nabla^2 f(x)) \le -\gamma$, whp the function value decreases by $f_{thres}$ within a perturbation plus $t_{thres}$ steps, i.e. by $f_{thres}/t_{thres}$ on average per step.
- Trap: the algorithm cannot keep making progress and escaping forever, because $f$ is bounded below! When it stops, the perturbation happened $t_{thres}$ steps ago but $f$ decreased by less than $f_{thres}$. That means $\|\nabla f(x)\| \le g_{thres}$ before the perturbation, and whp there is no eigenvalue $\le -\gamma$. So it's a local minimum!

Progress Lemma

Lemma. If $f$ is $l$-smooth, then for GD with step size $\eta < 1/l$, we have:
$$f(x_{t+1}) \le f(x_t) - \frac{\eta}{2}\|\nabla f(x_t)\|^2$$
Proof.
$$f(x_{t+1}) \le f(x_t) + \nabla f(x_t)^\top (x_{t+1} - x_t) + \frac{l}{2}\|x_{t+1} - x_t\|^2 = f(x_t) - \eta\|\nabla f(x_t)\|^2 + \frac{\eta^2 l}{2}\|\nabla f(x_t)\|^2 \le f(x_t) - \frac{\eta}{2}\|\nabla f(x_t)\|^2$$
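The descent inequality is easy to sanity-check numerically; a sketch on a random quadratic, whose smoothness constant $l$ is its largest Hessian eigenvalue (the test setup is illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((5, 5))
H = A @ A.T                      # PSD Hessian, so f(x) = 0.5 * x @ H @ x is convex
l = np.linalg.eigvalsh(H)[-1]    # smoothness constant = largest eigenvalue

f = lambda x: 0.5 * x @ H @ x
grad = lambda x: H @ x

eta = 0.9 / l                    # any step size eta < 1/l
x = rng.standard_normal(5)
x_next = x - eta * grad(x)
lhs = f(x_next)
rhs = f(x) - 0.5 * eta * np.linalg.norm(grad(x)) ** 2
assert lhs <= rhs                # f(x_{t+1}) <= f(x_t) - (eta/2) * ||grad f(x_t)||^2
print(lhs, rhs)
```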

Escape: main idea, and the "thin pancake" picture (figures omitted): around a saddle point, the set of starting points from which GD fails to escape forms a thin, pancake-shaped region; the next lemma measures its width.

Main Lemma: measure the width

Lemma. Suppose we start with a point $\tilde{x}$ satisfying the following conditions:
$$\|\nabla f(\tilde{x})\| \le g_{thres}, \qquad \lambda_{\min}(\nabla^2 f(\tilde{x})) \le -\gamma$$
Let $e_1$ be the eigenvector of the minimum eigenvalue. Consider two gradient descent sequences $\{u_t\}$, $\{w_t\}$, with initial points $u_0$, $w_0$ satisfying:
$$\|u_0 - \tilde{x}\| \le r, \qquad w_0 = u_0 + \mu r e_1, \quad \mu \in [\delta/(2\sqrt{d}), 1]$$
Then, for any step size $\eta \le c_{\max}/l$ and any $T \ge t_{thres}$, we have
$$\min\{f(u_T) - f(u_0),\ f(w_T) - f(w_0)\} \le -2.5 f_{thres}$$
As long as $u_0 - w_0$ lies along $e_1$ and $\|u_0 - w_0\| \ge \frac{\delta r}{2\sqrt{d}}$, at least one of them will escape!


Escape Case

Lemma (Escape case). Suppose we start with a point $\tilde{x}$ satisfying the following conditions:
$$\|\nabla f(\tilde{x})\| \le g_{thres}, \qquad \lambda_{\min}(\nabla^2 f(\tilde{x})) \le -\gamma$$
Let $x_0 = \tilde{x} + \xi$, where $\xi$ comes from the uniform distribution over a ball with radius $r$, and let $x_t$ be the iterates of GD from $x_0$. Then when $\eta \le c_{\max}/l$, with probability at least $1 - \delta$, for any $T \ge t_{thres}$:
$$f(x_T) - f(\tilde{x}) \le -f_{thres}$$

Proof of the escape lemma
- By smoothness, the perturbation step does not increase $f$ much:
$$f(x_0) - f(\tilde{x}) \le \nabla f(\tilde{x})^\top \xi + \frac{l}{2}\|\xi\|^2 \le 1.5 f_{thres}$$
- By the main lemma, for any $x_0 \in X_{stuck}$ we know $(x_0 \pm \mu r e_1) \notin X_{stuck}$ for $\mu \in [\delta/(2\sqrt{d}), 1]$, so the stuck region is thin along $e_1$:
$$\mathrm{Vol}(X_{stuck}) \le \mathrm{Vol}\!\left(B^{(d-1)}_{\tilde{x}}(r)\right) \cdot 2 \cdot \frac{\delta r}{2\sqrt{d}}$$
- Therefore, the probability that we picked a point in $X_{stuck}$ is bounded:
$$\frac{\mathrm{Vol}(X_{stuck})}{\mathrm{Vol}\!\left(B^{(d)}_{\tilde{x}}(r)\right)} \le \delta$$
- Thus, with probability at least $1 - \delta$, $x_0 \notin X_{stuck}$, and in this case, by the main lemma,
$$f(x_T) - f(\tilde{x}) \le -2.5 f_{thres} + 1.5 f_{thres} = -f_{thres}.$$

How to prove the main lemma?
- If $u_T$ does not decrease the function value, then $\{u_0, \dots, u_T\}$ stay close to $\tilde{x}$.
- If $\{u_0, \dots, u_T\}$ stay close to $\tilde{x}$, then GD from $w_0$ will decrease the function value.

We will need the following quadratic approximation:
$$\tilde{f}_y(x) = f(y) + \nabla f(y)^\top (x - y) + \frac{1}{2}(x - y)^\top H (x - y), \qquad H = \nabla^2 f(\tilde{x}).$$

Two lemmas (simplified)

Lemma ($u_T$-stuck). There exists an absolute constant $c_{\max}$ s.t., for any initial point $u_0$ with $\|u_0 - \tilde{x}\| \le r$, define
$$T = \min\left\{\inf_t \left\{t \,\middle|\, \tilde{f}_{u_0}(u_t) - f(u_0) \le -3 f_{thres}\right\},\ t_{thres}\right\}.$$
Then, for any $\eta \le c_{\max}/l$, we have $\|u_t - \tilde{x}\| \le \Phi$ for all $t < T$.

Lemma ($w_T$-escape). There exists an absolute constant $c_{\max}$ s.t., define
$$T = \min\left\{\inf_t \left\{t \,\middle|\, \tilde{f}_{w_0}(w_t) - f(w_0) \le -3 f_{thres}\right\},\ t_{thres}\right\}.$$
Then, for any $\eta \le c_{\max}/l$, if $\|u_t - \tilde{x}\| \le \Phi$ for all $t < T$, we have $T < t_{thres}$.

Prove the main lemma

Assume $\tilde{x}$ is the origin. Define
$$T = \inf_t \left\{t \,\middle|\, \tilde{f}_{u_0}(u_t) - f(u_0) \le -3 f_{thres}\right\}$$
Case $T \le t_{thres}$: We know $\|u_{T-1}\| \le \Phi$ by the $u_T$-stuck lemma. By a simple calculation, we can show that $\|u_T\| = O(\Phi)$ as well. Then
$$f(u_T) - f(u_0) \le \nabla f(u_0)^\top (u_T - u_0) + \frac{1}{2}(u_T - u_0)^\top \nabla^2 f(u_0)(u_T - u_0) + \frac{\rho}{6}\|u_T - u_0\|^3$$
$$\le \tilde{f}_{u_0}(u_T) - f(u_0) + \frac{\rho}{2}\|u_0 - \tilde{x}\|\,\|u_T - u_0\|^2 + \frac{\rho}{6}\|u_T - u_0\|^3 \le -2.5 f_{thres}$$

Case $T > t_{thres}$: By the $u_T$-stuck lemma, we know $\|u_t\| \le \Phi$ for all $t \le t_{thres}$. Using the $w_T$-escape lemma, we then know
$$T' = \min\left\{\inf_t \left\{t \,\middle|\, \tilde{f}_{w_0}(w_t) - f(w_0) \le -3 f_{thres}\right\},\ t_{thres}\right\} < t_{thres}$$
and we may reduce this to the case $T \le t_{thres}$, because $w$ and $u$ are interchangeable.

Prove the $u_T$-stuck lemma
- We won't move much along directions of large negative eigenvalue; otherwise the quadratic model would register a lot of progress!
- So consider $B_t$, the component of $u_t$ in the remaining subspace, spanned by eigendirections with eigenvalue $\ge -\frac{\gamma}{100}$. It satisfies
$$B_{t+1} \le \left(1 + \frac{\eta\gamma}{100}\right) B_t + 2\eta g_{thres}$$
- If $T \le t_{thres}$, we will have $\left(1 + \frac{\eta\gamma}{100}\right)^T \le 3$, so $B_T$ stays bounded.

Prove the $w_T$-escape lemma (as stated above).

Let $v_t = w_t - u_t$.
- We want to show that for some $T < t_{thres}$, $w_T$ has made progress.
- If $w_t$ makes no progress, then by the $u_T$-stuck lemma it is still near $\tilde{x}$. Therefore we would always have $\|v_t\| \le \|u_t\| + \|w_t\| \le 2\Phi$.
- However, $\|v_t\|$ increases very rapidly; it can't stay small forever! (See the sketch below.)
- Along the $e_1$ direction, $v_0$ has magnitude at least $\frac{\delta r}{2\sqrt{d}}$, and every step multiplies it by at least $1 + \eta\gamma$.
- So within $T < t_{thres}$ steps we get $\|v_T\| > 2\Phi$: a contradiction, so $w_T$ made progress!
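The growth of $v_t$ is most transparent when $f$ is exactly quadratic: then $v_{t+1} = (I - \eta H)v_t$, so the component along $e_1$ (eigenvalue $-\gamma$) is multiplied by exactly $1 + \eta\gamma$ per step. A minimal sketch (all values illustrative):

```python
import numpy as np

H = np.diag([-1.0, 2.0])       # quadratic saddle: lambda_min = -gamma = -1, e_1 = first axis
grad = lambda x: H @ x         # f(x) = 0.5 * x @ H @ x
eta = 0.1

u = np.array([0.3, 0.3])
w = u + 1e-6 * np.array([1.0, 0.0])   # u and w differ only along e_1
for _ in range(5):
    before = (w - u)[0]
    u, w = u - eta * grad(u), w - eta * grad(w)
    print((w - u)[0] / before)        # prints 1.1 = 1 + eta * gamma every step
```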