Mini-Course 1: SGD Escapes Saddle Points
Yang Yuan, Computer Science Department, Cornell University
Gradient Descent (GD)
Task: min_x f(x). GD performs iterative updates: x_{t+1} = x_t - η_t ∇f(x_t).
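As a concrete sketch (a toy example, not from the slides), the update rule can be written directly in NumPy; the quadratic objective and the step size here are hypothetical choices for illustration:

```python
import numpy as np

def gradient_descent(grad_f, x0, eta=0.1, steps=100):
    """Plain GD: x_{t+1} = x_t - eta * grad_f(x_t)."""
    x = x0
    for _ in range(steps):
        x = x - eta * grad_f(x)
    return x

# Toy objective f(x) = ||x||^2 / 2, so grad_f(x) = x; the minimum is at 0.
x_star = gradient_descent(lambda x: x, np.array([3.0, -4.0]))
```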
Gradient Descent (GD) has at least two problems:
- Computing the full gradient is slow for big data.
- It can get stuck at stationary points.
Stochastic Gradient Descent (SGD)
Very similar to GD, but the gradient now has some randomness: x_{t+1} = x_t - η_t g_t, where E[g_t] = ∇f(x_t).
Why do we use SGD?
Initially because:
- Much cheaper to compute using a mini-batch.
- Can still converge to the global minimum in the convex case.
But now people realize:
- It can escape saddle points! (Today's topic)
- It can escape shallow local minima. (Next time's topic; some progress.)
- It can find local minima that generalize well. (Not well understood.)
Therefore, it's not only faster, but also works better!
About the g_t that we use
x_{t+1} = x_t - η_t g_t, where E[g_t] = ∇f(x_t).
- In practice, g_t is obtained by sampling a mini-batch of size 128 or 256 from the dataset.
- To simplify the analysis, we assume g_t = ∇f(x_t) + ξ_t, where ξ_t ~ N(0, I) or ξ_t is uniform over the ball B_0(r).
- In general, the analysis works as long as ξ_t has a non-negligible component in every direction.
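Under the simplified noise model above, SGD is a one-line change to GD. A hypothetical sketch (the quadratic objective, step size, and seed are illustrative):

```python
import numpy as np

def sgd(grad_f, x0, eta=0.01, steps=2000, seed=0):
    """SGD under the simplified noise model g_t = grad_f(x_t) + xi_t, xi_t ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    x = x0
    for _ in range(steps):
        g = grad_f(x) + rng.standard_normal(x.shape)  # E[g] = grad_f(x)
        x = x - eta * g
    return x

# Toy objective f(x) = ||x||^2 / 2: SGD hovers near the minimum, up to noise of order eta.
x_final = sgd(lambda x: x, np.array([5.0, 5.0]))
```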
Preliminaries
- L-Lipschitz: |f(w_1) - f(w_2)| ≤ L ‖w_1 - w_2‖_2.
- l-smoothness: the gradient is l-Lipschitz, i.e., ‖∇f(w_1) - ∇f(w_2)‖_2 ≤ l ‖w_1 - w_2‖_2.
- ρ-Hessian smoothness: the Hessian matrix is ρ-Lipschitz, i.e., ‖∇²f(w_1) - ∇²f(w_2)‖_sp ≤ ρ ‖w_1 - w_2‖_2.
We need Hessian smoothness because we will use the Hessian at the current point to approximate f in a neighborhood, and then bound the approximation error.
Saddle points and negative eigenvalues
Stationary points: saddle points, local minima, local maxima
For stationary points, ∇f(w) = 0:
- If ∇²f(w) ≻ 0, it's a local minimum.
- If ∇²f(w) ≺ 0, it's a local maximum.
- If ∇²f(w) has both positive and negative eigenvalues, it's a saddle point.
- Degenerate case: ∇²f(w) has eigenvalues equal to 0. It could be either a local minimum (maximum) or a saddle point; f is flat in some directions, and SGD behaves like a random walk there.
We only consider the non-degenerate case!
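The classification above is mechanical once we have the Hessian's eigenvalues; a small sketch (the test functions are hypothetical):

```python
import numpy as np

def classify_stationary(hessian, tol=1e-8):
    """Classify a stationary point (where grad f = 0) by the Hessian's eigenvalues."""
    eigs = np.linalg.eigvalsh(hessian)
    if eigs.min() > tol:
        return "local minimum"
    if eigs.max() < -tol:
        return "local maximum"
    if eigs.min() < -tol and eigs.max() > tol:
        return "saddle point"
    return "degenerate"  # some eigenvalue is (numerically) zero

# f(x, y) = x^2 - y^2 has a saddle at the origin: its Hessian is diag(2, -2).
kind = classify_stationary(np.diag([2.0, -2.0]))
```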
Strict saddle property
f is (α, γ, ɛ, ζ)-strict saddle if, for any w, at least one of the following holds:
- ‖∇f(w)‖_2 ≥ ɛ. (The gradient is large.)
- λ_min(∇²f(w)) ≤ -γ < 0. (Near a stationary point, we have a negative-eigenvalue direction to escape along.)
- There exists a local minimum w* such that ‖w - w*‖_2 ≤ ζ, and the region centered at w* with radius 2ζ is α-strongly convex. (Stationary point, no negative eigenvalues: we are pretty close to a local minimum.)
Strict saddle functions are everywhere
- Orthogonal tensor decomposition [Ge et al 2015]
- Deep linear (residual) networks [Kawaguchi 2016], [Hardt and Ma 2016]
- Matrix completion [Ge et al 2016]
- Generalized phase retrieval problem [Sun et al 2016]
- Low-rank matrix recovery [Bhojanapalli et al 2016]
Moreover, in these problems, all local minima are equally good! That means:
- SGD escapes all saddle points.
- So, SGD arrives at a local minimum = a global minimum!
This is one popular way to prove that SGD solves the problem.
Main results
- [Ge et al 2015]: whp, SGD will escape all saddle points and converge to a local minimum. The convergence time has polynomial dependence on the dimension d.
- [Jin et al 2017]: whp, PGD (a variant of SGD) will escape all saddle points and converge to a local minimum much faster. The dependence on d is logarithmic.
Both use the same proof framework. We'll mainly look at the new result.
Description of PGD
Do the following iteratively:
- If ‖∇f(x_t)‖ ≤ g_thres and the last perturbation was more than t_thres steps ago, add a random perturbation sampled uniformly from a ball.
- If the perturbation happened t_thres steps ago but f has decreased by less than f_thres, return the value from before the last perturbation.
- Do a gradient descent step: x_{t+1} = x_t - η ∇f(x_t).
A few remarks:
- Unfortunately, it is not a fast algorithm, because it uses full GD steps!
- η = c/l; the thresholds g_thres, t_thres, f_thres depend on a constant c, as well as the other parameters.
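A minimal sketch of the two main steps (perturb when the gradient is small, otherwise descend). The threshold values below are illustrative placeholders, not the ones derived in the paper, and the sketch omits the termination rule that returns the pre-perturbation point:

```python
import numpy as np

def pgd(grad_f, x0, eta=0.1, g_thres=1e-3, t_thres=50, r=1e-2, steps=500, seed=0):
    """Sketch of Perturbed Gradient Descent: GD plus an occasional ball perturbation."""
    rng = np.random.default_rng(seed)
    x = x0
    last_perturb = -t_thres - 1
    for t in range(steps):
        if np.linalg.norm(grad_f(x)) <= g_thres and t - last_perturb > t_thres:
            xi = rng.standard_normal(x.shape)
            # rescale to a uniform sample from the ball of radius r
            xi *= r * rng.random() ** (1.0 / x.size) / np.linalg.norm(xi)
            x = x + xi
            last_perturb = t
        x = x - eta * grad_f(x)  # gradient descent step
    return x

# At the saddle of f(x, y) = x^2 - y^2, plain GD started at the origin never moves;
# the perturbation pushes PGD onto the escape direction (the y-axis).
x_out = pgd(lambda x: np.array([2.0 * x[0], -2.0 * x[1]]), np.zeros(2))
```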
Main theorem in [Jin et al 2017]
Theorem (Main Theorem). Assume f is l-smooth, ρ-Hessian Lipschitz, and (α, γ, ɛ, ζ)-strict saddle. There exists an absolute constant c_max such that for any δ > 0, Δ_f ≥ f(x_0) - f*, constant c ≤ c_max, and ɛ' = min{ɛ, γ²/ρ}, PGD(c) will output a point ζ-close to a local minimum with probability 1 - δ, and terminate in the following number of iterations:
O( (l (f(x_0) - f*) / ɛ'²) · log⁴( d l Δ_f / (ɛ'² δ) ) )
- If we could show that SGD has a similar property, that would be great!
- The convergence rate is almost optimal.
More general version: why it's fast
Theorem (A more general version). Assume f is l-smooth and ρ-Hessian Lipschitz. There exists an absolute constant c_max such that for any δ > 0, Δ_f ≥ f(x_0) - f*, constant c ≤ c_max, and ɛ ≤ l²/ρ, PGD(c) will output an ɛ-second-order stationary point with probability 1 - δ, and terminate in the following number of iterations:
O( (l (f(x_0) - f*) / ɛ²) · log⁴( d l Δ_f / (ɛ² δ) ) )
Essentially saying the same thing. If f is not strict saddle, only an ɛ-second-order stationary point (instead of a local minimum) is guaranteed.
ɛ-stationary points
- ɛ-first-order stationary point: ‖∇f(x)‖ ≤ ɛ.
- ɛ-second-order stationary point: ‖∇f(x)‖ ≤ ɛ and λ_min(∇²f(x)) ≥ -√(ρɛ).
- If f is l-smooth, then λ_min(∇²f(x)) ≥ -l.
- For any ɛ > l²/ρ, an ɛ-first-order stationary point of an l-smooth function is automatically an (l²/ρ)-second-order stationary point.
- If f is (α, γ, ɛ, ζ)-strict saddle and ɛ' < γ²/ρ, then any ɛ'-second-order stationary point is ζ-close to a local minimum.
[Nesterov, 1998]
Theorem. Assume that f is l-smooth. Then for any ɛ > 0, if we run GD with step size η = 1/l and termination condition ‖∇f(x)‖ ≤ ɛ, the output will be an ɛ-first-order stationary point, and the algorithm terminates in the following number of iterations:
l (f(x_0) - f*) / ɛ²
[Jin et al 2017]: PGD converges to an ɛ-second-order stationary point in O( (l (f(x_0) - f*) / ɛ²) · log⁴( d l Δ_f / (ɛ² δ) ) ) steps. Matched up to log factors!
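Nesterov's bound can be checked numerically on a toy l-smooth function; everything below (the quadratic objective, the starting point) is a hypothetical example, with l = 1 and f* = 0:

```python
import numpy as np

# Toy l-smooth function: f(x) = (x1^2 + 0.01 * x2^2) / 2, Hessian diag(1, 0.01), so l = 1.
l, eps = 1.0, 1e-2
grad = lambda x: np.array([x[0], 0.01 * x[1]])
x = np.array([10.0, 10.0])
budget = int(np.ceil(l * 0.5 * (x[0]**2 + 0.01 * x[1]**2) / eps**2))  # l*(f(x0) - f*)/eps^2
steps = 0
while np.linalg.norm(grad(x)) > eps:   # the theorem's termination condition
    x = x - (1.0 / l) * grad(x)        # GD with step size 1/l
    steps += 1
```

On this example GD stops well inside the theorem's iteration budget, as the bound guarantees.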
Why √(ρɛ)? [Nesterov and Polyak, 2006]
If we use the third-order (cubic-regularized) approximation at x:
min_y { ⟨∇f(x), y - x⟩ + (1/2) ⟨∇²f(x)(y - x), y - x⟩ + (ρ/6) ‖y - x‖³ },
denote the minimizer as T_x, and denote the distance r = ‖x - T_x‖. Then:
‖∇f(T_x)‖ ≤ ρ r², and ∇²f(T_x) ⪰ -(3/2) ρ r I.
This gives a lower bound for r:
r ≥ max{ √(‖∇f(T_x)‖/ρ), -(2/(3ρ)) λ_min(∇²f(T_x)) }.
The two bounds are of the same order exactly when ‖∇f(T_x)‖ ≈ ɛ and λ_min(∇²f(T_x)) ≈ -√(ρɛ), which is where the √(ρɛ) threshold comes from.
Related results
1. "Gradient Descent Converges to Minimizers", by Lee, Simchowitz, Jordan and Recht, '15: with random initialization, GD almost surely never touches any saddle point, and always converges to local minima.
2. "The power of normalization: faster evasion of saddle points", Kfir Levy, '16: normalized gradient descent can escape saddle points in O(d³ poly(1/ɛ)) iterations, slower than [Jin et al 2017], faster than [Ge et al 2015], but still polynomial in d.
Main theorem in [Jin et al 2017] (recap)
Theorem (Main Theorem). Assume f is l-smooth, ρ-Hessian Lipschitz, and (α, γ, ɛ, ζ)-strict saddle. There exists an absolute constant c_max such that for any δ > 0, Δ_f ≥ f(x_0) - f*, constant c ≤ c_max, and ɛ' = min{ɛ, γ²/ρ}, PGD(c) will output a point ζ-close to a local minimum with probability 1 - δ, and terminate in O( (l (f(x_0) - f*) / ɛ'²) · log⁴( d l Δ_f / (ɛ'² δ) ) ) iterations.
Proof framework: Progress, Escape and Trap
- Progress: when ‖∇f(x)‖ > g_thres, f(x) is decreased by at least f_thres/t_thres per step.
- Escape: when ‖∇f(x)‖ ≤ g_thres and λ_min(∇²f(x)) ≤ -γ, whp the function value decreases by f_thres within t_thres steps after the perturbation, i.e., by f_thres/t_thres on average per step.
- Trap: the algorithm can't make progress and escape forever, because f is bounded below! When it stops, the perturbation happened t_thres steps ago but f decreased by less than f_thres. That means ‖∇f(x)‖ ≤ g_thres before the perturbation, and whp there is no eigenvalue ≤ -γ. So it's (ζ-close to) a local minimum!
Progress
Lemma. If f is l-smooth, then for GD with step size η < 1/l, we have:
f(x_{t+1}) ≤ f(x_t) - (η/2) ‖∇f(x_t)‖²
Proof.
f(x_{t+1}) ≤ f(x_t) + ∇f(x_t)ᵀ(x_{t+1} - x_t) + (l/2) ‖x_{t+1} - x_t‖²
          = f(x_t) - η ‖∇f(x_t)‖² + (η² l / 2) ‖∇f(x_t)‖²
          ≤ f(x_t) - (η/2) ‖∇f(x_t)‖²
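The lemma is easy to check numerically on a toy l-smooth function; here f(x) = ‖Ax‖²/2 with a random A (all choices hypothetical), whose smoothness constant is the largest eigenvalue of AᵀA:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
f = lambda x: 0.5 * float((A @ x) @ (A @ x))
grad = lambda x: A.T @ (A @ x)
l = float(np.linalg.eigvalsh(A.T @ A).max())  # smoothness constant of f

eta = 0.5 / l                 # any step size eta < 1/l works
x = rng.standard_normal(5)
x_next = x - eta * grad(x)    # one GD step
decrease = f(x) - f(x_next)   # the lemma guarantees decrease >= (eta/2) * ||grad f(x)||^2
```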
Escape: main idea
Escape: thin pancake
Main Lemma: measure the width
Lemma. Suppose we start with a point x̃ satisfying the following conditions:
‖∇f(x̃)‖ ≤ g_thres, λ_min(∇²f(x̃)) ≤ -γ
Let e_1 be the minimum-eigenvalue eigenvector. Consider two gradient descent sequences {u_t}, {w_t}, with initial points u_0, w_0 satisfying:
‖u_0 - x̃‖ ≤ r, w_0 = u_0 + µ r e_1, µ ∈ [δ/(2√d), 1]
Then, for any step size η ≤ c_max/l and any T ≥ t_thres, we have:
min{ f(u_T) - f(u_0), f(w_T) - f(w_0) } ≤ -2.5 f_thres
As long as w_0 - u_0 is along e_1 with length at least δr/(2√d), at least one of the two sequences will escape!
Escape Case
Lemma (Escape case). Suppose we start with a point x̃ satisfying the following conditions:
‖∇f(x̃)‖ ≤ g_thres, λ_min(∇²f(x̃)) ≤ -γ
Let x_0 = x̃ + ξ, where ξ comes from the uniform distribution over the ball with radius r, and let x_t be the iterates of GD from x_0. Then when η < c_max/l, with probability at least 1 - δ, for any T ≥ t_thres:
f(x_T) - f(x̃) ≤ -f_thres
Proof of the escape lemma
- By smoothness, the perturbation step does not increase f much:
f(x_0) - f(x̃) ≤ ∇f(x̃)ᵀξ + (l/2)‖ξ‖² ≤ 1.5 f_thres
- By the main lemma, for any x_0 ∈ X_stuck, we know (x_0 ± µ r e_1) ∉ X_stuck for µ ∈ [δ/(2√d), 1]. So the stuck region is thin in the e_1 direction:
Vol(X_stuck) ≤ Vol(B_x̃^(d-1)(r)) · δr/√d
- Therefore, the probability that we pick a point in X_stuck is bounded:
Vol(X_stuck) / Vol(B_x̃^(d)(r)) ≤ δ
- Thus, with probability at least 1 - δ, x_0 ∉ X_stuck, and in this case, by the main lemma:
f(x_T) - f(x̃) ≤ -2.5 f_thres + 1.5 f_thres = -f_thres
How to prove the main Lemma?
- If u_T does not decrease the function value, then {u_0, ..., u_T} stay close to x̃.
- If {u_0, ..., u_T} are close to x̃, then GD starting from w_0 will decrease the function value.
We will need the following quadratic approximation, with the Hessian frozen at x̃:
f̂_y(x) = f(y) + ∇f(y)ᵀ(x - y) + (1/2)(x - y)ᵀ H (x - y), where H = ∇²f(x̃).
Two lemmas (simplified)
Lemma (u_T-stuck). There exists an absolute constant c_max s.t., for any initial point u_0 with ‖u_0 - x̃‖ ≤ r, define
T* = min{ inf_t { t : f̂_{u_0}(u_t) - f(u_0) ≤ -3 f_thres }, t_thres }
Then, for any η ≤ c_max/l, we have ‖u_t - x̃‖ ≤ Φ for all t < T*.
Lemma (w_T-escape). There exists an absolute constant c_max s.t., define
T* = min{ inf_t { t : f̂_{w_0}(w_t) - f(w_0) ≤ -3 f_thres }, t_thres }
Then, for any η ≤ c_max/l, if ‖u_t - x̃‖ ≤ Φ for all t < T*, we have T* < t_thres.
Prove the main lemma
Assume x̃ is the origin. Define
T* = inf_t { t : f̂_{u_0}(u_t) - f(u_0) ≤ -3 f_thres }
Case T* ≤ t_thres: We know ‖u_{T*-1}‖ ≤ Φ by the u_T-stuck lemma. By a simple calculation, we can show that ‖u_{T*}‖ = O(Φ) as well. Then:
f(u_{T*}) - f(u_0) ≤ ∇f(u_0)ᵀ(u_{T*} - u_0) + (1/2)(u_{T*} - u_0)ᵀ∇²f(u_0)(u_{T*} - u_0) + (ρ/6)‖u_{T*} - u_0‖³
≤ f̂_{u_0}(u_{T*}) - f(u_0) + (ρ/2)‖u_0 - x̃‖‖u_{T*} - u_0‖² + (ρ/6)‖u_{T*} - u_0‖³
≤ -2.5 f_thres
Case T* > t_thres: By the u_T-stuck lemma, we know ‖u_t‖ ≤ Φ for all t ≤ t_thres. Using the w_T-escape lemma, we know
T*' = inf_t { t : f̂_{w_0}(w_t) - f(w_0) ≤ -3 f_thres } < t_thres
Then we may reduce this to the case T* ≤ t_thres, because w and u are interchangeable.
Prove the u_T-stuck lemma
Lemma (u_T-stuck). There exists an absolute constant c_max s.t., for any initial point u_0 with ‖u_0 - x̃‖ ≤ r, define
T* = min{ inf_t { t : f̂_{u_0}(u_t) - f(u_0) ≤ -3 f_thres }, t_thres }
Then, for any η ≤ c_max/l, we have ‖u_t - x̃‖ ≤ Φ for all t < T*.
Proof idea:
- We won't move much in the directions of large negative eigenvalues; otherwise that would already be a lot of progress!
- Let B_t be the component of u_t in the remaining subspace, where the eigenvalues are at least -γ/100. Then:
B_{t+1} ≤ (1 + ηγ/100) B_t + 2η g_thres
- If T ≤ t_thres, we have (1 + ηγ/100)^T ≤ 3, so B_T is bounded.
Prove the w_T-escape lemma
Lemma (w_T-escape). There exists an absolute constant c_max s.t., define
T* = min{ inf_t { t : f̂_{w_0}(w_t) - f(w_0) ≤ -3 f_thres }, t_thres }
Then, for any η ≤ c_max/l, if ‖u_t - x̃‖ ≤ Φ for all t < T*, we have T* < t_thres.
Prove the w_T-escape lemma (continued)
Let v_t = w_t - u_t.
- We want to show that for some T < t_thres, w_T has made progress.
- If w_t makes no progress, then by the u_T-stuck lemma it is still near x̃.
- Therefore, we always have ‖v_t‖ ≤ ‖u_t‖ + ‖w_t‖ ≤ 2Φ.
- However, v_t increases very rapidly; it can't stay small forever!
- In the e_1 direction, v_0 has magnitude at least δr/(2√d).
- Every step, this component is multiplied by at least 1 + ηγ.
- So within T < t_thres steps we would get ‖v_T‖ > 2Φ, a contradiction: w_T must have made progress!
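The heart of this argument is that a component of size δr/(2√d), multiplied by (1 + ηγ) each step, overwhelms any fixed bound 2Φ within a logarithmic number of steps. A toy numerical check (all constants below are hypothetical):

```python
import numpy as np

# v_0's e_1-component starts at delta*r/(2*sqrt(d)) and at least multiplies by
# (1 + eta*gamma) each step, so it exceeds 2*Phi after roughly
# log(4 * Phi * sqrt(d) / (delta * r)) / log(1 + eta * gamma) steps.
d, delta, r, eta, gamma, Phi = 1000, 0.1, 1e-3, 1e-2, 0.5, 1.0
v = delta * r / (2 * np.sqrt(d))   # initial separation along e_1
steps = 0
while v <= 2 * Phi:
    v *= 1 + eta * gamma
    steps += 1
```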