A Tutorial On Backward Propagation Through Time (BPTT) In The Gated Recurrent Unit (GRU) RNN

Minchen Li
Department of Computer Science
The University of British Columbia
minchenl@cs.ubc.ca

Abstract

In this tutorial, we provide a thorough explanation of how BPTT in GRU¹ is conducted. A MATLAB program which implements the entire BPTT for GRU is presented, together with pseudo-code describing the algorithms explicitly. We provide two algorithms for BPTT: a direct but quadratic-time algorithm for easy understanding, and an optimized linear-time algorithm. This tutorial starts with a specification of the problem, followed by a mathematical derivation, before the computational solutions.

1 Specification

We want to use a dataset containing $n_s$ sentences, each with $n_w$ words, to train a GRU language model, and our vocabulary size is $n_v$. Namely, we have input $x \in \mathbb{R}^{n_v \times n_w \times n_s}$ and label $y \in \mathbb{R}^{n_v \times n_w \times n_s}$, both representing $n_s$ sentences. For simplicity, let's look at one sentence at a time. In one sentence, the one-hot vector $x_t \in \mathbb{R}^{n_v \times 1}$ represents the $t$-th word. For time step $t$, the GRU unit computes the output $\hat{y}_t$ using the input $x_t$ and the previous internal state $s_{t-1}$ as follows:

$$
\begin{aligned}
z_t &= \sigma(U_z x_t + W_z s_{t-1} + b_z) \\
r_t &= \sigma(U_r x_t + W_r s_{t-1} + b_r) \\
h_t &= \tanh(U_h x_t + W_h (s_{t-1} \circ r_t) + b_h) \\
s_t &= (1 - z_t) \circ h_t + z_t \circ s_{t-1} \\
\hat{y}_t &= \mathrm{softmax}(V s_t + b_V)
\end{aligned}
\tag{1}
$$

Here $\circ$ is the element-wise vector multiplication, $\sigma(\cdot)$ is the element-wise sigmoid function, and $\tanh(\cdot)$ is the element-wise hyperbolic tangent function. The dimensions of the parameters are as follows:

$$
\begin{aligned}
U_z, U_r, U_h &\in \mathbb{R}^{n_i \times n_v} \\
W_z, W_r, W_h &\in \mathbb{R}^{n_i \times n_i} \\
b_z, b_r, b_h &\in \mathbb{R}^{n_i \times 1} \\
V \in \mathbb{R}^{n_v \times n_i}, &\quad b_V \in \mathbb{R}^{n_v \times 1}
\end{aligned}
$$

where $n_i$ is the internal memory size set by the user.

¹ GRU is an improved version of the traditional RNN (Recurrent Neural Network); see WildML.com for an introduction. This link also provides an introduction to GRU and some general discussion on BPTT and beyond.
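As a concrete reading of Eq. 1, the following minimal sketch evaluates a single GRU time step. The sizes, the random parameter values, the chosen word index, and the helper names `sigm` and `softmax_col` are assumptions made for illustration; the helpers are defined inline so that the sketch does not depend on any toolbox.

```matlab
% Sketch of one GRU time step following Eq. 1.
% All sizes and values are illustrative assumptions, not taken from the tutorial.
n_v = 6;                                 % vocabulary size
n_i = 4;                                 % internal memory size

sigm = @(a) 1 ./ (1 + exp(-a));          % element-wise sigmoid
softmax_col = @(a) exp(a - max(a)) ./ sum(exp(a - max(a)));  % softmax of a column

% randomly initialized parameters with the dimensions listed above
U_z = rand(n_i, n_v); W_z = rand(n_i, n_i); b_z = rand(n_i, 1);
U_r = rand(n_i, n_v); W_r = rand(n_i, n_i); b_r = rand(n_i, 1);
U_h = rand(n_i, n_v); W_h = rand(n_i, n_i); b_h = rand(n_i, 1);
V   = rand(n_v, n_i); b_V = rand(n_v, 1);

x_t = zeros(n_v, 1); x_t(3) = 1;         % one-hot vector for word index 3
s_prev = zeros(n_i, 1);                  % s_{t-1}

z_t = sigm(U_z * x_t + W_z * s_prev + b_z);            % update gate
r_t = sigm(U_r * x_t + W_r * s_prev + b_r);            % reset gate
h_t = tanh(U_h * x_t + W_h * (s_prev .* r_t) + b_h);   % candidate state
s_t = (1 - z_t) .* h_t + z_t .* s_prev;                % new internal state
y_hat = softmax_col(V * s_t + b_V);                    % predicted distribution

disp(sum(y_hat));                        % the probabilities sum to 1 up to rounding
```

Because $x_t$ is one-hot, each product $U x_t$ simply selects one column of $U$, which is why the input weights have $n_v$ columns.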
Then for step $t$, we can calculate the cross-entropy loss $L_t$ as:

$$
L_t = \mathrm{sumOfAllElements}\left(-y_t \circ \log(\hat{y}_t)\right)
\tag{2}
$$

Here $\log$ is also an element-wise function. To train the GRU, we want to know the values of all parameters that minimize the total loss $L = \sum_{t=1}^{n_w} L_t$:

$$
\operatorname*{argmin}_{\Theta} L
$$

where $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$. This is a non-convex problem with huge input data, so people usually use the Stochastic Gradient Descent² method to solve it, which means we need to calculate $\partial L/\partial U_z$, $\partial L/\partial U_r$, $\partial L/\partial U_h$, $\partial L/\partial W_z$, $\partial L/\partial W_r$, $\partial L/\partial W_h$, $\partial L/\partial b_z$, $\partial L/\partial b_r$, $\partial L/\partial b_h$, $\partial L/\partial V$, and $\partial L/\partial b_V$ given a batch of sentences. (Note that these parameters stay the same in every time step.) In this tutorial we consider using only one sentence at a time to keep it concise.

2 Derivation

The best way to calculate gradients using the Chain Rule from output to input is to first draw the expression graph of the entire model, in order to figure out the relations between the output, the intermediate results, and the input³. Here we draw part of the expression graph of GRU in Fig. 1.

Figure 1: The upper part of the expression graph describing the operations of GRU. Note that the sub-graph which $s_{t-1}$ depends on is just like the sub-graph of $s_t$. This is what the red dashed lines mean.

With this expression graph, the Chain Rule works if you go backwards along the edges (top-down). If a node $X$ has multiple outgoing edges connecting to the target node $T$, you need to sum over the partial derivatives of each of those outgoing edges to derive the gradient $\partial T/\partial X$. We will illustrate the rules in the following paragraphs.

Let's take $\partial L/\partial U_z$ as the example here; the others are similar. Since $L = \sum_{t=1}^{n_w} L_t$ and the parameters stay the same in each step, we also have $\partial L/\partial U_z = \sum_{t=1}^{n_w} \partial L_t/\partial U_z$, so let's calculate each $\partial L_t/\partial U_z$ independently and sum them up.

² See the Wikipedia to get some knowledge about Stochastic Gradient Descent.
³ See colah's blog and the Stanford CS231n Course Note for some general introductions.
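With the Chain Rule, we have:

$$
\frac{\partial L_t}{\partial U_z} = \frac{\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial U_z}
\tag{3}
$$

The first part is trivial if you know how to differentiate the cross-entropy loss function embedded with the softmax function:

$$
\frac{\partial L_t}{\partial s_t} = V^T (\hat{y}_t - y_t)
$$

For $\partial s_t/\partial U_z$, similarly, some people might just derive (if they know how to differentiate the sigmoid function):

$$
\overline{\frac{\partial s_t}{\partial U_z}} = \left((s_{t-1} - h_t) \circ z_t \circ (1 - z_t)\right) x_t^T
\tag{4}
$$

Here there are two expressions, $1 - z_t$ and $z_t \circ s_{t-1}$, influencing $\partial s_t/\partial z_t$, as shown in our expression graph. The solution is to derive the partial derivatives through each edge and then add them up, which is exactly how we deal with $\partial s_t/\partial s_{t-1}$, as you will see in the following paragraphs. However, Eq. 4 only calculates one part of the gradient, so we put a bar on top of it; you will find this partial result very useful in the following calculations.

Note that $s_{t-1}$ also depends on $U_z$, so we cannot treat it as a constant here. Moreover, this $s_{t-1}$ will also introduce the influence of $s_i$, where $i = 1, \ldots, t-2$. So for clarity, we should expand Eq. 3 as:

$$
\frac{\partial L_t}{\partial U_z}
= \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right)
= \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \left( \prod_{j=i}^{t-1} \frac{\partial s_{j+1}}{\partial s_j} \right) \overline{\frac{\partial s_i}{\partial U_z}} \right)
\tag{5}
$$

where $\overline{\partial s_i/\partial U_z}$ is the gradient of $s_i$ with respect to $U_z$ while taking $s_{i-1}$ as a constant, of which a similar example has been shown in Eq. 4 for step $t$.

The derivation of $\partial s_t/\partial s_{t-1}$ is similar to the derivation of $\partial s_t/\partial z_t$ discussed above. Since there are four outgoing edges from $s_{t-1}$ to $s_t$, directly and indirectly through $z_t$, $r_t$, and $h_t$ in the expression graph, we need to sum all four partial derivatives together:

$$
\begin{aligned}
\frac{\partial s_t}{\partial s_{t-1}}
&= \frac{\partial s_t}{\partial h_t} \frac{\partial h_t}{\partial s_{t-1}} + \frac{\partial s_t}{\partial z_t} \frac{\partial z_t}{\partial s_{t-1}} + \overline{\frac{\partial s_t}{\partial s_{t-1}}} \\
&= \frac{\partial s_t}{\partial h_t} \left( \frac{\partial h_t}{\partial r_t} \frac{\partial r_t}{\partial s_{t-1}} + \overline{\frac{\partial h_t}{\partial s_{t-1}}} \right) + \frac{\partial s_t}{\partial z_t} \frac{\partial z_t}{\partial s_{t-1}} + \overline{\frac{\partial s_t}{\partial s_{t-1}}}
\end{aligned}
$$

where $\overline{\partial s_t/\partial s_{t-1}}$ is the gradient of $s_t$ with respect to $s_{t-1}$ while taking $h_t$ and $z_t$ as constants. Similarly, $\overline{\partial h_t/\partial s_{t-1}}$ is the gradient of $h_t$ with respect to $s_{t-1}$ while taking $r_t$ as a constant. Plugging the intermediate results into the above formula, we get:

$$
\frac{\partial s_t}{\partial s_{t-1}}
= (1 - z_t) \left( W_r^T \left( (W_h^T (1 - h_t \circ h_t)) \circ s_{t-1} \circ r_t \circ (1 - r_t) \right) + \left( (W_h^T (1 - h_t \circ h_t)) \circ r_t \right) \right)
+ W_z^T \left( (s_{t-1} - h_t) \circ z_t \circ (1 - z_t) \right) + z_t
\tag{6}
$$

Eq. 6 is written in backpropagation shorthand rather than as a literal matrix product: each factor is applied, from the outside in, to the gradient arriving at $s_t$, which is how the implementation in Section 4 evaluates it.

Till now, we have covered all the components needed to calculate $\partial L_t/\partial U_z$.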
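To make the step from $s_t$ back to $s_{t-1}$ in Eq. 6 concrete, it can be written as a short sequence of vector operations. This is a sketch rather than part of the original derivation: the symbols $\delta_t$, $\delta^h_t$, $g_t$, $\delta^r_t$, and $\delta^z_t$ are names introduced here for illustration, with $\delta_t$ denoting the gradient of the loss with respect to $s_t$ as a column vector.

```latex
\begin{aligned}
\delta^h_t &= \delta_t \circ (1 - z_t) \circ (1 - h_t \circ h_t)
  && \text{gradient at the input of } \tanh \\
g_t &= W_h^T \delta^h_t
  && \text{gradient with respect to } s_{t-1} \circ r_t \\
\delta^r_t &= g_t \circ s_{t-1} \circ r_t \circ (1 - r_t)
  && \text{gradient at the input of the sigmoid for } r_t \\
\delta^z_t &= \delta_t \circ (s_{t-1} - h_t) \circ z_t \circ (1 - z_t)
  && \text{gradient at the input of the sigmoid for } z_t \\
\delta_{t-1} &= g_t \circ r_t + W_r^T \delta^r_t + W_z^T \delta^z_t + \delta_t \circ z_t
  && \text{sum over the four edges into } s_{t-1}
\end{aligned}
```

Under this notation, the one-step contribution to the gradient of $U_z$ from Eq. 4 is $\delta^z_t x_t^T$, and the four terms of $\delta_{t-1}$ correspond one-to-one to the four outgoing edges of $s_{t-1}$ in the expression graph.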
The gradients of $L_t$ with respect to the other parameters are similar. In the next section, we provide a more mechanical view of the calculation: the pseudo-code describing the algorithm that calculates the gradients. In the last section of this tutorial, we provide the pure machine representation: a MATLAB program which implements the calculation and verification of BPTT. If you just want to understand the idea behind BPTT and plan to use a fully supported auto-differentiation package (like Theano⁴) to build your own GRU, you can stop here. If you need to implement the exact chain rule like us, or are just curious about what happens next, get ready to proceed!

⁴ Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently.
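One way to gain confidence in the chain-rule derivation before reading the full program is to check a single backward step against central differences. The sketch below does this for the gradient with respect to $s_{t-1}$ only. The toy loss $L = w^T s_t$, the vector `w`, the helper names, and all sizes are assumptions made for illustration and do not appear in the tutorial's own code.

```matlab
% Sketch: check one backward GRU step (s_t -> s_{t-1}) numerically.
% The toy loss L = w' * s_t and every name below are illustrative assumptions.
n_i = 3; n_v = 5;
sigm = @(a) 1 ./ (1 + exp(-a));          % element-wise sigmoid
U_z = rand(n_i, n_v); W_z = rand(n_i, n_i); b_z = rand(n_i, 1);
U_r = rand(n_i, n_v); W_r = rand(n_i, n_i); b_r = rand(n_i, 1);
U_h = rand(n_i, n_v); W_h = rand(n_i, n_i); b_h = rand(n_i, 1);
x = zeros(n_v, 1); x(2) = 1;             % one-hot input word
s_prev = rand(n_i, 1);                   % s_{t-1}
w = rand(n_i, 1);                        % defines the toy loss, so dL/ds_t = w

% forward step of Eq. 1, written as functions of s_{t-1}
zf = @(s) sigm(U_z * x + W_z * s + b_z);
rf = @(s) sigm(U_r * x + W_r * s + b_r);
hf = @(s) tanh(U_h * x + W_h * (s .* rf(s)) + b_h);
sf = @(s) (1 - zf(s)) .* hf(s) + zf(s) .* s;
loss = @(s) w' * sf(s);

% analytic backward step: sum over the four edges from s_{t-1} to s_t
z = zf(s_prev); r = rf(s_prev); h = hf(s_prev);
d_s = w;                                         % gradient arriving at s_t
d_h = d_s .* (1 - z) .* (1 - h .* h);            % at the input of tanh
g   = W_h' * d_h;                                % w.r.t. (s_{t-1} .* r_t)
d_r = g .* s_prev .* r .* (1 - r);               % at the input of sigmoid (r)
d_z = d_s .* (s_prev - h) .* z .* (1 - z);       % at the input of sigmoid (z)
d_s_prev = g .* r + W_r' * d_r + W_z' * d_z + d_s .* z;

% numerical gradient by central differences
step = 1e-5;
d_num = zeros(n_i, 1);
for k = 1:n_i
    e = zeros(n_i, 1); e(k) = step;
    d_num(k) = (loss(s_prev + e) - loss(s_prev - e)) / (2 * step);
end

max_err = max(abs(d_num - d_s_prev));
disp(max_err);                           % small if the analytic step is right
```

The same pattern (perturb one entry, rerun the forward pass, difference the losses) is what the gradient checker in the full program applies to every parameter.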
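3 Algorithm

Here we again take $\partial L/\partial U_z$ as the example; the calculation of all the gradients is provided in the next section. We present two algorithms: a direct algorithm, as derived previously, which calculates each $\partial L_t/\partial U_z$ and sums them up in $O(n_w^2)$ time, and an $O(n_w)$ time algorithm which we will see later.

Algorithm 1: A direct but $O(n_w^2)$ time algorithm to calculate $\partial L/\partial U_z$ (and beyond)

Input: The training data $X, Y \in \mathbb{R}^{n_v \times n_w}$ composed of the one-hot column vectors $x_t, y_t \in \mathbb{R}^{n_v \times 1}$, $t = 1, 2, \ldots, n_w$, representing the words in the sentence.
Input: A vector $s_0 \in \mathbb{R}^{n_i \times 1}$ representing the initial internal state of the model (usually set to 0).
Input: The parameters $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$ of the model.
Output: The total loss gradient $\partial L/\partial U_z$.

1: % forward propagate to calculate the internal states $S \in \mathbb{R}^{n_i \times n_w}$, the predictions $\hat{Y} \in \mathbb{R}^{n_v \times n_w}$, the losses $L_{mtr} \in \mathbb{R}^{n_w \times 1}$, and the intermediate results $Z, R, H \in \mathbb{R}^{n_i \times n_w}$ of each step:
2: $[S, \hat{Y}, L_{mtr}, Z, R, H] = \mathrm{forward}(X, Y, \Theta, s_0)$ % forward() can be implemented easily according to Eq. 1 and Eq. 2
3: $dU_z = \mathrm{zeros}(n_i, n_v)$ % initialize a variable $dU_z$
4: $\partial L_{mtr}/\partial S = V^T (\hat{Y} - Y)$ % calculate $\partial L_t/\partial s_t$ for $t = 1, 2, \ldots, n_w$ with one matrix operation
5: for $t = 1$ to $n_w$ % calculate each $\partial L_t/\partial U_z$ and accumulate
6:     for $j = t$ to $1$ % calculate each $(\partial L_t/\partial s_j)(\overline{\partial s_j/\partial U_z})$ and accumulate
7:         $\partial L_t/\partial z_j = \partial L_t/\partial s_j \circ (s_{j-1} - h_j)$ % $\partial s_j/\partial z_j$ is $(s_{j-1} - h_j)$; $\partial L_t/\partial s_j$ is calculated in the last inner loop iteration or in Line 4
8:         $\partial L_t/\partial (U_z x_j + W_z s_{j-1} + b_z) = \partial L_t/\partial z_j \circ z_j \circ (1 - z_j)$ % $\partial \sigma(x)/\partial x = \sigma(x) \circ (1 - \sigma(x))$
9:         $dU_z \mathrel{+}= \left(\partial L_t/\partial (U_z x_j + W_z s_{j-1} + b_z)\right) x_j^T$ % accumulate
10:        calculate $\partial L_t/\partial s_{j-1}$ using $\partial L_t/\partial s_j$ and Eq. 6 % for the next inner loop iteration
11:    end
12: end
13: return $dU_z$ % $\partial L/\partial U_z$

The above direct algorithm follows Eq. 5 to calculate each $\partial L_t/\partial U_z$ and then adds them up to form $\partial L/\partial U_z$:

$$
\begin{aligned}
\frac{\partial L}{\partial U_z}
&= \sum_{t=1}^{n_w} \frac{\partial L_t}{\partial U_z} \\
&= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right) \\
&= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \left( \prod_{j=i}^{t-1} \frac{\partial s_{j+1}}{\partial s_j} \right) \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
\end{aligned}
$$

If we expand $\partial L_t/\partial U_z$ only to the second line of the above equation and do some reordering, we get:

$$
\begin{aligned}
\frac{\partial L}{\partial U_z}
&= \sum_{t=1}^{n_w} \left( \frac{\partial L_t}{\partial s_t} \sum_{i=1}^{t} \left( \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right) \\
&= \sum_{t=1}^{n_w} \left( \sum_{i=1}^{t} \left( \frac{\partial L_t}{\partial s_t} \frac{\partial s_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right) \\
&= \sum_{t=1}^{n_w} \left( \sum_{i=1}^{t} \left( \frac{\partial L_t}{\partial s_i} \overline{\frac{\partial s_i}{\partial U_z}} \right) \right)
\end{aligned}
$$

Right now the inner summation keeps the subscript of $L_t$ and iterates over $s_i$. If we further expand the inner summation and then sort the terms to iterate over $L_i$, we get:

$$
\frac{\partial L}{\partial U_z}
= \sum_{t=1}^{n_w} \left( \sum_{i=t}^{n_w} \left( \frac{\partial L_i}{\partial s_t} \overline{\frac{\partial s_t}{\partial U_z}} \right) \right)
= \sum_{t=1}^{n_w} \left( \left( \sum_{i=t}^{n_w} \frac{\partial L_i}{\partial s_t} \right) \overline{\frac{\partial s_t}{\partial U_z}} \right)
\tag{7}
$$

For the inner summation of Eq. 7, we have:

$$
\begin{aligned}
\sum_{i=t}^{n_w} \frac{\partial L_i}{\partial s_t}
&= \left( \sum_{i=t+1}^{n_w} \frac{\partial L_i}{\partial s_t} \right) + \frac{\partial L_t}{\partial s_t} \\
&= \left( \sum_{i=t+1}^{n_w} \frac{\partial L_i}{\partial s_{t+1}} \frac{\partial s_{t+1}}{\partial s_t} \right) + \frac{\partial L_t}{\partial s_t} \\
&= \left( \sum_{i=t+1}^{n_w} \frac{\partial L_i}{\partial s_{t+1}} \right) \frac{\partial s_{t+1}}{\partial s_t} + \frac{\partial L_t}{\partial s_t}
\end{aligned}
\tag{8}
$$

This gives us an updating formula to calculate the inner summation for each step $t$ incrementally rather than executing another for loop, which makes an $O(n_w)$ time algorithm possible.

Algorithm 2: An optimized $O(n_w)$ time algorithm to calculate $\partial L/\partial U_z$ (and beyond)

Input: The training data $X, Y \in \mathbb{R}^{n_v \times n_w}$ composed of the one-hot column vectors $x_t, y_t \in \mathbb{R}^{n_v \times 1}$, $t = 1, 2, \ldots, n_w$, representing the words in the sentence.
Input: A vector $s_0 \in \mathbb{R}^{n_i \times 1}$ representing the initial internal state of the model (usually set to 0).
Input: The parameters $\Theta = \{U_z, U_r, U_h, W_z, W_r, W_h, b_z, b_r, b_h, V, b_V\}$ of the model.
Output: The total loss gradient $\partial L/\partial U_z$.

1: % forward propagate to calculate the internal states $S \in \mathbb{R}^{n_i \times n_w}$, the predictions $\hat{Y} \in \mathbb{R}^{n_v \times n_w}$, the losses $L_{mtr} \in \mathbb{R}^{n_w \times 1}$, and the intermediate results $Z, R, H \in \mathbb{R}^{n_i \times n_w}$ of each step:
2: $[S, \hat{Y}, L_{mtr}, Z, R, H] = \mathrm{forward}(X, Y, \Theta, s_0)$ % forward() can be implemented easily according to Eq. 1 and Eq. 2
3: $dU_z = \mathrm{zeros}(n_i, n_v)$ % initialize a variable $dU_z$
4: $\partial L_{mtr}/\partial S = V^T (\hat{Y} - Y)$ % calculate $\partial L_t/\partial s_t$ for $t = 1, 2, \ldots, n_w$ with one matrix operation
5: for $t = n_w$ to $1$ % calculate each $\left(\sum_{i=t}^{n_w} \partial L_i/\partial s_t\right) \overline{\partial s_t/\partial U_z}$ and accumulate
6:     $\sum_{i=t}^{n_w} (\partial L_i/\partial z_t) = \left(\sum_{i=t}^{n_w} \partial L_i/\partial s_t\right) \circ (s_{t-1} - h_t)$ % $\partial s_t/\partial z_t$ is $(s_{t-1} - h_t)$; $\sum_{i=t}^{n_w} (\partial L_i/\partial s_t)$ is calculated in the last iteration or in Line 4 (when $t = n_w$, $\sum_{i=t}^{n_w} (\partial L_i/\partial s_t) = \partial L_t/\partial s_t$)
7:     $\sum_{i=t}^{n_w} \left(\partial L_i/\partial (U_z x_t + W_z s_{t-1} + b_z)\right) = \sum_{i=t}^{n_w} (\partial L_i/\partial z_t) \circ z_t \circ (1 - z_t)$ % $\partial \sigma(x)/\partial x = \sigma(x) \circ (1 - \sigma(x))$
8:     $dU_z \mathrel{+}= \left(\sum_{i=t}^{n_w} \partial L_i/\partial (U_z x_t + W_z s_{t-1} + b_z)\right) x_t^T$ % accumulate
9:     calculate $\sum_{i=t-1}^{n_w} (\partial L_i/\partial s_{t-1})$ using Eq. 6 and Eq. 8 % for the next iteration
10: end
11: return $dU_z$ % $\partial L/\partial U_z$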
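The reordering behind Eq. 7 and Eq. 8 can be checked in isolation from the GRU itself. In the sketch below, the per-step gradients and the step-to-step Jacobians are random stand-ins (assumptions for illustration, not values from a real model), and the variable names are hypothetical. It accumulates $\sum_{i \ge t} \partial L_i/\partial s_t$ once with the nested loops of Algorithm 1 and once with the single backward sweep of Algorithm 2.

```matlab
% Sketch of the reordering behind Eq. 7 and Eq. 8 on stand-in data.
% g(:, t) plays the role of dL_t/ds_t and J{t} the role of ds_{t+1}/ds_t;
% both are random placeholders (assumptions), not values from a real GRU.
n_i = 3;                                 % state size
n_w = 6;                                 % number of time steps
g = rand(n_i, n_w);
J = cell(n_w - 1, 1);
for t = 1:n_w - 1
    J{t} = 0.5 * rand(n_i, n_i);
end

% Direct O(n_w^2) version: back-propagate every L_t separately (Algorithm 1).
acc_direct = zeros(n_i, n_w);            % column j collects sum_{t>=j} dL_t/ds_j
for t = 1:n_w
    d = g(:, t);                         % dL_t/ds_t
    for j = t:-1:1
        acc_direct(:, j) = acc_direct(:, j) + d;
        if j > 1
            d = J{j - 1}' * d;           % dL_t/ds_{j-1} from dL_t/ds_j
        end
    end
end

% Linear O(n_w) version: carry the running sum backwards once (Algorithm 2).
acc_linear = zeros(n_i, n_w);
d = zeros(n_i, 1);
for t = n_w:-1:1
    d = d + g(:, t);                     % add dL_t/ds_t, as in Eq. 8
    acc_linear(:, t) = d;                % sum_{i>=t} dL_i/ds_t
    if t > 1
        d = J{t - 1}' * d;               % push the whole sum one step back
    end
end

max_diff = max(abs(acc_direct(:) - acc_linear(:)));
disp(max_diff);                          % the two accumulations agree up to rounding
```

The two versions agree because back-propagation through one step is linear in the incoming gradient, so the sum of the individually propagated gradients equals the propagated sum.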
4 Implementation

Here we provide the MATLAB program which calculates the gradients with respect to all the parameters of GRU using our two proposed algorithms. It also checks the gradients against numerical results. We divide our code into two parts: the first part, presented below, contains the core functions implementing the BPTT of GRU we just derived; the second part is composed of some functions that are less important to the topic of this tutorial. (In the code, the candidate state $h_t$ of Eq. 1 and its parameters $U_h$, $W_h$, $b_h$ are named c, U_c, W_c, and b_c.)

Core Functions

    % This program tests the BPTT process we manually developed for GRU.
    % We calculate the gradients of GRU parameters with chain rule, and then
    % compare them to the numerical gradients to check whether our chain rule
    % derivation is correct.

    % Here, we provided 2 versions of BPTT, backward_direct() and backward().
    % The former one is the direct idea to calculate gradient within each step
    % and add them up (O(sentence_size^2) time). The latter one is optimized to
    % calculate the contribution of each step to the overall gradient, which is
    % only O(sentence_size) time.

    % This is very helpful for people who want to implement GRU in Caffe since
    % Caffe didn't support auto-differentiation. This is also very helpful for
    % the people who want to know the details about Backpropagation Through
    % Time algorithm in the Recurrent Neural Networks (such as GRU and LSTM)
    % and also get a sense on how auto-differentiation is possible.

    % NOTE: We didn't involve SGD training here. With SGD training, this
    % program would become a complete implementation of GRU which can be
    % trained with sequence data. However, since this is only a CPU serial
    % Matlab version of GRU, applying it on large datasets will be dramatically
    % slow.

    % by Minchen Li, at The University of British Columbia. 2016-04-21

    function testBPTT_GRU
        % set GRU and data scale
        vocabulary_size = 64;
        iMem_size = 4;
        sentence_size = 20; % number of words in a sentence
                            %(including start and end symbol)
                            % since we will only use one sentence for training,
                            % this is also the total steps during training.

        [x, y] = getTrainingData(vocabulary_size, sentence_size);

        % initialize parameters:
        % multiplier for input x_t of intermediate variables
        U_z = rand(iMem_size, vocabulary_size);
        U_r = rand(iMem_size, vocabulary_size);
        U_c = rand(iMem_size, vocabulary_size);
        % multiplier for previous s of intermediate variables
        W_z = rand(iMem_size, iMem_size);
        W_r = rand(iMem_size, iMem_size);
        W_c = rand(iMem_size, iMem_size);
        % bias terms of intermediate variables
        b_z = rand(iMem_size, 1);
        b_r = rand(iMem_size, 1);
        b_c = rand(iMem_size, 1);
        % decoder for generating output
        V = rand(vocabulary_size, iMem_size);
        b_V = rand(vocabulary_size, 1); % bias of decoder
        % previous s of step 1
        s_0 = rand(iMem_size, 1);

        % calculate and check gradient
        tic
        [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
            backward_direct(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
        toc
        tic
        checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
            dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0);
        toc

        tic
        [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
            backward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
        toc
        tic
        checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
            dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0);
        toc
    end

    % Forward propagate calculate s, y_hat, loss and intermediate variables for each step
    function [s, y_hat, L, z, r, c] = forward(x, y, ...
        U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
        % count sizes
        [vocabulary_size, sentence_size] = size(x);
        iMem_size = size(V, 2);

        % initialize results
        s = zeros(iMem_size, sentence_size);
        y_hat = zeros(vocabulary_size, sentence_size);
        L = zeros(sentence_size, 1);
        z = zeros(iMem_size, sentence_size);
        r = zeros(iMem_size, sentence_size);
        c = zeros(iMem_size, sentence_size);

        % calculate result for step 1 since s_0 is not in s
        z(:,1) = sigmoid(U_z*x(:,1) + W_z*s_0 + b_z);
        r(:,1) = sigmoid(U_r*x(:,1) + W_r*s_0 + b_r);
        c(:,1) = tanh(U_c*x(:,1) + W_c*(s_0.*r(:,1)) + b_c);
        s(:,1) = (1-z(:,1)).*c(:,1) + z(:,1).*s_0;
        y_hat(:,1) = softmax(V*s(:,1) + b_V);
        L(1) = sum(-y(:,1).*log(y_hat(:,1)));
        % calculate results for step 2 - sentence_size similarly
        for wordI = 2:sentence_size
            z(:,wordI) = sigmoid(U_z*x(:,wordI) + W_z*s(:,wordI-1) + b_z);
            r(:,wordI) = sigmoid(U_r*x(:,wordI) + W_r*s(:,wordI-1) + b_r);
            c(:,wordI) = tanh(U_c*x(:,wordI) + W_c*(s(:,wordI-1).*r(:,wordI)) + b_c);
            s(:,wordI) = (1-z(:,wordI)).*c(:,wordI) + z(:,wordI).*s(:,wordI-1);
            y_hat(:,wordI) = softmax(V*s(:,wordI) + b_V);
            L(wordI) = sum(-y(:,wordI).*log(y_hat(:,wordI)));
        end
    end

    % Backward propagate to calculate gradient using chain rule
    % (O(sentence_size) time)
    function [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
        % forward propagate to get the intermediate and output results
        [s, y_hat, L, z, r, c] = forward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, ...
            b_z, b_r, b_c, V, b_V, s_0);
        % count sentence size
        [~, sentence_size] = size(x);

        % calculate gradient using chain rule
        delta_y = y_hat - y;
        db_V = sum(delta_y, 2);

        dV = zeros(size(V));
        for wordI = 1:sentence_size
            dV = dV + delta_y(:,wordI)*s(:,wordI)';
        end

        ds_0 = zeros(size(s_0));
        dU_c = zeros(size(U_c));
        dU_r = zeros(size(U_r));
        dU_z = zeros(size(U_z));
        dW_c = zeros(size(W_c));
        dW_r = zeros(size(W_r));
        dW_z = zeros(size(W_z));
        db_z = zeros(size(b_z));
        db_r = zeros(size(b_r));
        db_c = zeros(size(b_c));
        ds_single = V'*delta_y;
        % calculate the derivative contribution of each step and add them up
        ds_cur = zeros(size(ds_single,1), 1);
        for wordJ = sentence_size:-1:2
            ds_cur = ds_cur + ds_single(:,wordJ);
            ds_cur_bk = ds_cur;

            dtanhInput = (ds_cur.*(1-z(:,wordJ)).*(1-c(:,wordJ).*c(:,wordJ)));
            db_c = db_c + dtanhInput;
            dU_c = dU_c + dtanhInput*x(:,wordJ)'; %could be accelerated by avoiding add 0
            dW_c = dW_c + dtanhInput*(s(:,wordJ-1).*r(:,wordJ))';
            dsr = W_c'*dtanhInput;
            ds_cur = dsr.*r(:,wordJ);
            dsigInput_r = dsr.*s(:,wordJ-1).*r(:,wordJ).*(1-r(:,wordJ));
            db_r = db_r + dsigInput_r;
            dU_r = dU_r + dsigInput_r*x(:,wordJ)'; %could be accelerated by avoiding add 0
            dW_r = dW_r + dsigInput_r*s(:,wordJ-1)';
            ds_cur = ds_cur + W_r'*dsigInput_r;

            ds_cur = ds_cur + ds_cur_bk.*z(:,wordJ);
            dz = ds_cur_bk.*(s(:,wordJ-1)-c(:,wordJ));
            dsigInput_z = dz.*z(:,wordJ).*(1-z(:,wordJ));
            db_z = db_z + dsigInput_z;
            dU_z = dU_z + dsigInput_z*x(:,wordJ)'; %could be accelerated by avoiding add 0
            dW_z = dW_z + dsigInput_z*s(:,wordJ-1)';
            ds_cur = ds_cur + W_z'*dsigInput_z;
        end

        % s_1
        ds_cur = ds_cur + ds_single(:,1);

        dtanhInput = (ds_cur.*(1-z(:,1)).*(1-c(:,1).*c(:,1)));
        db_c = db_c + dtanhInput;
        dU_c = dU_c + dtanhInput*x(:,1)'; %could be accelerated by avoiding add 0
        dW_c = dW_c + dtanhInput*(s_0.*r(:,1))';
        dsr = W_c'*dtanhInput;
        ds_0 = ds_0 + dsr.*r(:,1);
        dsigInput_r = dsr.*s_0.*r(:,1).*(1-r(:,1));
        db_r = db_r + dsigInput_r;
        dU_r = dU_r + dsigInput_r*x(:,1)'; %could be accelerated by avoiding add 0
        dW_r = dW_r + dsigInput_r*s_0';
        ds_0 = ds_0 + W_r'*dsigInput_r;

        ds_0 = ds_0 + ds_cur.*z(:,1);
        dz = ds_cur.*(s_0-c(:,1));
        dsigInput_z = dz.*z(:,1).*(1-z(:,1));
        db_z = db_z + dsigInput_z;
        dU_z = dU_z + dsigInput_z*x(:,1)'; %could be accelerated by avoiding add 0
        dW_z = dW_z + dsigInput_z*s_0';
        ds_0 = ds_0 + W_z'*dsigInput_z;
    end

    % A more direct view of backward propagate to calculate gradient using
    % chain rule. (O(sentence_size^2) time)
    % Instead of calculating how much contribution of derivative each step has,
    % here we calculate the gradient within every step.
    function [dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0] = ...
        backward_direct(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0)
        % forward propagate to get the intermediate and output results
        [s, y_hat, L, z, r, c] = forward(x, y, U_z, U_r, U_c, W_z, W_r, W_c, ...
            b_z, b_r, b_c, V, b_V, s_0);
        % count sentence size
        [~, sentence_size] = size(x);

        % calculate gradient using chain rule
        delta_y = y_hat - y;
        db_V = sum(delta_y, 2);

        dV = zeros(size(V));
        for wordI = 1:sentence_size
            dV = dV + delta_y(:,wordI)*s(:,wordI)';
        end

        ds_0 = zeros(size(s_0));
        dU_c = zeros(size(U_c));
        dU_r = zeros(size(U_r));
        dU_z = zeros(size(U_z));
        dW_c = zeros(size(W_c));
        dW_r = zeros(size(W_r));
        dW_z = zeros(size(W_z));
        db_z = zeros(size(b_z));
        db_r = zeros(size(b_r));
        db_c = zeros(size(b_c));
        ds_single = V'*delta_y;
        % calculate the derivatives in each step and add them up
        for wordI = 1:sentence_size
            ds_cur = ds_single(:,wordI);
            % since in each step t, the derivatives depends on s_0-s_t,
            % we need to trace back from t to 0 each time
            for wordJ = wordI:-1:2
                ds_cur_bk = ds_cur;

                dtanhInput = (ds_cur.*(1-z(:,wordJ)).*(1-c(:,wordJ).*c(:,wordJ)));
                db_c = db_c + dtanhInput;
                dU_c = dU_c + dtanhInput*x(:,wordJ)'; %could be accelerated by avoiding add 0
                dW_c = dW_c + dtanhInput*(s(:,wordJ-1).*r(:,wordJ))';
                dsr = W_c'*dtanhInput;
                ds_cur = dsr.*r(:,wordJ);
                dsigInput_r = dsr.*s(:,wordJ-1).*r(:,wordJ).*(1-r(:,wordJ));
                db_r = db_r + dsigInput_r;
                dU_r = dU_r + dsigInput_r*x(:,wordJ)'; %could be accelerated by avoiding add 0
                dW_r = dW_r + dsigInput_r*s(:,wordJ-1)';
                ds_cur = ds_cur + W_r'*dsigInput_r;

                ds_cur = ds_cur + ds_cur_bk.*z(:,wordJ);
                dz = ds_cur_bk.*(s(:,wordJ-1)-c(:,wordJ));
                dsigInput_z = dz.*z(:,wordJ).*(1-z(:,wordJ));
                db_z = db_z + dsigInput_z;
                dU_z = dU_z + dsigInput_z*x(:,wordJ)'; %could be accelerated by avoiding add 0
                dW_z = dW_z + dsigInput_z*s(:,wordJ-1)';
                ds_cur = ds_cur + W_z'*dsigInput_z;
            end

            % s_1
            dtanhInput = (ds_cur.*(1-z(:,1)).*(1-c(:,1).*c(:,1)));
            db_c = db_c + dtanhInput;
            dU_c = dU_c + dtanhInput*x(:,1)'; %could be accelerated by avoiding add 0
            dW_c = dW_c + dtanhInput*(s_0.*r(:,1))';
            dsr = W_c'*dtanhInput;
            ds_0 = ds_0 + dsr.*r(:,1);
            dsigInput_r = dsr.*s_0.*r(:,1).*(1-r(:,1));
            db_r = db_r + dsigInput_r;
            dU_r = dU_r + dsigInput_r*x(:,1)'; %could be accelerated by avoiding add 0
            dW_r = dW_r + dsigInput_r*s_0';
            ds_0 = ds_0 + W_r'*dsigInput_r;

            ds_0 = ds_0 + ds_cur.*z(:,1);
            dz = ds_cur.*(s_0-c(:,1));
            dsigInput_z = dz.*z(:,1).*(1-z(:,1));
            db_z = db_z + dsigInput_z;
            dU_z = dU_z + dsigInput_z*x(:,1)'; %could be accelerated by avoiding add 0
            dW_z = dW_z + dsigInput_z*s_0';
            ds_0 = ds_0 + W_z'*dsigInput_z;
        end
    end

    % Sigmoid function for neural network
    function val = sigmoid(x)
        val = sigmf(x,[1 0]);
    end

testBPTT_GRU.m

Less Important Functions

    % Fake a training dataset: generate only one sentence for training.
    % !!! Only for testing. Needs to be changed to read in training data from files.
    function [x_t, y_t] = getTrainingData(vocabulary_size, sentence_size)
        assert(vocabulary_size > 2); % for start and end of sentence symbol
        assert(sentence_size > 0);

        % define start and end of sentence in the vocabulary
        SENTENCE_START = zeros(vocabulary_size, 1);
        SENTENCE_START(1) = 1;
        SENTENCE_END = zeros(vocabulary_size, 1);
        SENTENCE_END(2) = 1;

        % generate sentence:
        x_t = zeros(vocabulary_size, sentence_size-1); % leave one slot for SENTENCE_START
        for wordI = 1:sentence_size-1
            % generate a random word excludes start and end symbol
            x_t(randi(vocabulary_size-2,1,1)+2, wordI) = 1;
        end
        y_t = [x_t, SENTENCE_END]; % training output
        x_t = [SENTENCE_START, x_t]; % training input
    end

    % Use numerical differentiation to approximate the gradient of each
    % parameter and calculate the difference between these numerical results
    % and our results calculated by applying chain rule.
    function checkGradient_GRU(x, y, U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0, ...
        dV, db_V, dU_z, dU_r, dU_c, dW_z, dW_r, dW_c, db_z, db_r, db_c, ds_0)
        % Here we use the centre difference formula:
        % df(x)/dx = (f(x+h)-f(x-h)) / (2h)
        % It is a second order accurate method with error bounded by O(h^2)

        h = 1e-5;
        % NOTE: h couldn't be too large or too small since large h will
        % introduce bigger truncation error and small h will introduce bigger
        % roundoff error.

        dV_numerical = zeros(size(dV));
        % Calculate partial derivative element by element
        for rowI = 1:size(dV_numerical,1)
            for colI = 1:size(dV_numerical,2)
                V_plus = V;
                V_plus(rowI,colI) = V_plus(rowI,colI) + h;
                V_minus = V;
                V_minus(rowI,colI) = V_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V_plus, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V_minus, b_V, s_0);
                dV_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dV_numerical-dV)./(abs(dV_numerical)+h))), ...
            'dV relative error'); % prevent dividing by 0 by adding h

        dU_c_numerical = zeros(size(dU_c));
        for rowI = 1:size(dU_c_numerical,1)
            for colI = 1:size(dU_c_numerical,2)
                U_c_plus = U_c;
                U_c_plus(rowI,colI) = U_c_plus(rowI,colI) + h;
                U_c_minus = U_c;
                U_c_minus(rowI,colI) = U_c_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r, U_c_plus, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r, U_c_minus, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                dU_c_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dU_c_numerical-dU_c)./(abs(dU_c_numerical)+h))), ...
            'dU_c relative error');

        dW_c_numerical = zeros(size(dW_c));
        for rowI = 1:size(dW_c_numerical,1)
            for colI = 1:size(dW_c_numerical,2)
                W_c_plus = W_c;
                W_c_plus(rowI,colI) = W_c_plus(rowI,colI) + h;
                W_c_minus = W_c;
                W_c_minus(rowI,colI) = W_c_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r, W_c_plus, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r, W_c_minus, b_z, b_r, b_c, V, b_V, s_0);
                dW_c_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dW_c_numerical-dW_c)./(abs(dW_c_numerical)+h))), ...
            'dW_c relative error');

        dU_r_numerical = zeros(size(dU_r));
        for rowI = 1:size(dU_r_numerical,1)
            for colI = 1:size(dU_r_numerical,2)
                U_r_plus = U_r;
                U_r_plus(rowI,colI) = U_r_plus(rowI,colI) + h;
                U_r_minus = U_r;
                U_r_minus(rowI,colI) = U_r_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r_plus, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r_minus, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                dU_r_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dU_r_numerical-dU_r)./(abs(dU_r_numerical)+h))), ...
            'dU_r relative error');

        dW_r_numerical = zeros(size(dW_r));
        for rowI = 1:size(dW_r_numerical,1)
            for colI = 1:size(dW_r_numerical,2)
                W_r_plus = W_r;
                W_r_plus(rowI,colI) = W_r_plus(rowI,colI) + h;
                W_r_minus = W_r;
                W_r_minus(rowI,colI) = W_r_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r_plus, W_c, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z, W_r_minus, W_c, b_z, b_r, b_c, V, b_V, s_0);
                dW_r_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dW_r_numerical-dW_r)./(abs(dW_r_numerical)+h))), ...
            'dW_r relative error');

        dU_z_numerical = zeros(size(dU_z));
        for rowI = 1:size(dU_z_numerical,1)
            for colI = 1:size(dU_z_numerical,2)
                U_z_plus = U_z;
                U_z_plus(rowI,colI) = U_z_plus(rowI,colI) + h;
                U_z_minus = U_z;
                U_z_minus(rowI,colI) = U_z_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z_plus, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z_minus, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                dU_z_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dU_z_numerical-dU_z)./(abs(dU_z_numerical)+h))), ...
            'dU_z relative error');

        dW_z_numerical = zeros(size(dW_z));
        for rowI = 1:size(dW_z_numerical,1)
            for colI = 1:size(dW_z_numerical,2)
                W_z_plus = W_z;
                W_z_plus(rowI,colI) = W_z_plus(rowI,colI) + h;
                W_z_minus = W_z;
                W_z_minus(rowI,colI) = W_z_minus(rowI,colI) - h;
                [~, ~, L_plus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z_plus, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                [~, ~, L_minus] = forward(x, y, ...
                    U_z, U_r, U_c, W_z_minus, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0);
                dW_z_numerical(rowI,colI) = (sum(L_plus) - sum(L_minus)) / 2 / h;
            end
        end
        display(sum(sum(abs(dW_z_numerical-dW_z)./(abs(dW_z_numerical)+h))), ...
            'dW_z relative error');

        db_z_numerical = zeros(size(db_z));
        for i = 1:length(db_z_numerical)
            b_z_plus = b_z;
            b_z_plus(i) = b_z_plus(i) + h;
            b_z_minus = b_z;
            b_z_minus(i) = b_z_minus(i) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z_plus, b_r, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z_minus, b_r, b_c, V, b_V, s_0);
            db_z_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
        display(sum(abs(db_z_numerical-db_z)./(abs(db_z_numerical)+h)), ...
            'db_z relative error');

        db_r_numerical = zeros(size(db_r));
        for i = 1:length(db_r_numerical)
            b_r_plus = b_r;
            b_r_plus(i) = b_r_plus(i) + h;
            b_r_minus = b_r;
            b_r_minus(i) = b_r_minus(i) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r_plus, b_c, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r_minus, b_c, V, b_V, s_0);
            db_r_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
        display(sum(abs(db_r_numerical-db_r)./(abs(db_r_numerical)+h)), ...
            'db_r relative error');

        db_c_numerical = zeros(size(db_c));
        for i = 1:length(db_c_numerical)
            b_c_plus = b_c;
            b_c_plus(i) = b_c_plus(i) + h;
            b_c_minus = b_c;
            b_c_minus(i) = b_c_minus(i) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c_plus, V, b_V, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c_minus, V, b_V, s_0);
            db_c_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
        display(sum(abs(db_c_numerical-db_c)./(abs(db_c_numerical)+h)), ...
            'db_c relative error');

        db_V_numerical = zeros(size(db_V));
        for i = 1:length(db_V_numerical)
            b_V_plus = b_V;
            b_V_plus(i) = b_V_plus(i) + h;
            b_V_minus = b_V;
            b_V_minus(i) = b_V_minus(i) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V_plus, s_0);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V_minus, s_0);
            db_V_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
        display(sum(abs(db_V_numerical-db_V)./(abs(db_V_numerical)+h)), ...
            'db_V relative error');

        ds_0_numerical = zeros(size(ds_0));
        for i = 1:length(ds_0_numerical)
            s_0_plus = s_0;
            s_0_plus(i) = s_0_plus(i) + h;
            s_0_minus = s_0;
            s_0_minus(i) = s_0_minus(i) - h;
            [~, ~, L_plus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0_plus);
            [~, ~, L_minus] = forward(x, y, ...
                U_z, U_r, U_c, W_z, W_r, W_c, b_z, b_r, b_c, V, b_V, s_0_minus);
            ds_0_numerical(i) = (sum(L_plus) - sum(L_minus)) / 2 / h;
        end
        display(sum(abs(ds_0_numerical-ds_0)./(abs(ds_0_numerical)+h)), ...
            'ds_0 relative error');
    end

testBPTT_GRU.m
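The centre difference check used by the gradient checker can also be tried on a function whose gradient is known in closed form. The sketch below is only an illustration of that formula and of the summed relative error measure: the toy function $f(P) = \sum \sin(P)$, the matrix `P`, and the variable names are assumptions, not part of the program above. Linear indexing with `numel` lets one loop cover matrices and vectors alike.

```matlab
% Sketch: centre difference gradient check on a toy function.
% f, P and all names here are illustrative assumptions; the exact gradient
% of f(P) = sum(sin(P(:))) is cos(P).
f = @(P) sum(sin(P(:)));
P = rand(3, 4);
h = 1e-5;                                % same step size as in the checker

G_numerical = zeros(size(P));
for k = 1:numel(P)                       % linear index over every element
    P_plus = P;  P_plus(k) = P_plus(k) + h;
    P_minus = P; P_minus(k) = P_minus(k) - h;
    G_numerical(k) = (f(P_plus) - f(P_minus)) / 2 / h;
end

G_exact = cos(P);
% summed relative error, adding h to the denominator to avoid dividing by 0
rel_err = sum(abs(G_numerical(:) - G_exact(:)) ./ (abs(G_numerical(:)) + h));
disp(rel_err);
```

A relative error that stays small for such a known case suggests that the step size and the error measure are reasonable before they are applied to the GRU gradients.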