RNN & LSTM
2020-2ํ๊ธฐ โ์ฐ๊ตฌ์ฐธ์ฌ(CSED339A)โ์์ ์งํํ โStanford CS231โ ์คํฐ๋์์ ๊ณต๋ถํ ๋ด์ฉ์ ์ ๋ฆฌํ ํฌ์คํธ์ ๋๋ค. ์ง์ ์ ์ธ์ ๋ ํ์์ ๋๋ค :)
Introduction to Sequential Model
์ด๋ฒ ํฌ์คํธ์์ ๋ค๋ฃจ๋ ๋ชจ๋ธ์ โ์์โ๋ผ๋ ์ฑ์ง์ ๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ชจ๋ธ๋ค์ด๋ค. ์๋ฅผ ๋ค๋ฉด, ๋จ์ด(word)๋ ๋ฌธ์ฅ(sentence), 1๋ ๊ฐ์ ์ฃผ์ ๊ฐ๊ฒฉ ๋ฑ์ด ์์๊ฐ ์ค์ํ๊ฒ ์ฌ๊ฒจ์ง๋ ๋ฐ์ดํฐ๋ค์ด๋ค. ์ด๋ฐ ๋ฐ์ดํฐ๋ฅผ <sequence data>๋ผ๊ณ ํ๋ค.
๊ธฐ์กด <Feed Forward Network>๋ ๊ฐ์ด ์ ๋ฐฉ(ๅๆน)์ผ๋ก๋ง ํ๋ฅด๊ธฐ ๋๋ฌธ์ ๊ณผ๊ฑฐ์ ์ ๋ณด๋ ์ค์ํ <sequential data>๋ฅผ ์ฒ๋ฆฌํ๊ธฐ์๋ ์ ํฉํ์ง ์์๋ค. ๊ทธ๋ฌ๋ <RNN>, <LSTM>์ ๊ฐ์ด ๊ทธ ์ํ์ (sequential)ํ ๊ตฌ์กฐ๋ก ์ธํด ๊ณผ๊ฑฐ ๋ฐ์ดํฐ์ history๋ฅผ โ๊ธฐ์ตโํ ์ ์์ผ๋ฉฐ, <sequence data>๋ฅผ ์ฒ๋ฆฌํ๋ ๊ฒ์ ํนํ๋์ด ์๋ค.
RNN; Recurrent Neural Network
<RNN; Recurrent Neural Network>์์๋ ์๋์ธต์์ ๊ณ์ฐํ ๊ฐ์ด ์ถ๋ ฅ์ธต ๋ฐฉํฅ์ผ๋ก๋ ์ ํ๋์ง๋ง, ๋ค์ ์๋์ธต์ผ๋ก ๋์์ ์๋์ธต์ hidden state์ ์ ์ฅ๋๋ค. ์ด ๊ฐ์ ๋ค์ ์ ๋ ฅ์ ์ฒ๋ฆฌํ ๋ ํ์ฉ๋๋ค!
<RNN>์ ์์ ๊ฐ์ด ํํํ ์๋ ์์ง๋ง, ์๋์ ๊ฐ์ด iteration์ ํ์ด์ ํํํ ์๋ ์๋ค. ์ด๊ฒ์ <Cell>์ด๋ผ๊ณ ํ๋ค.
์๋์ธต์ $t$์์ ์์์ ์ถ๋ ฅ $h_t$๋ฅผ ๊ตฌํ๊ธฐ ์ํด ๋ ๊ฐ์ง ๊ฐ์ ํ์ฉํ๋ค.
- ์ด์ ์์ ์ hidden state $h_{t-1}$
- ํ์ฌ ์์ ์ ์ ๋ ฅ $x_t$
์ด ๋ ๊ฐ์ ๊ฐ์ค์น์ ํจ๊ป ์ ์กฐํฉํด $\tanh$ ํจ์๋ฅผ ํ์ฑ ํจ์ ์ผ์ ์ถ๋ ฅํ๋ฉด, ์๋์ธต์ ์ถ๋ ฅ $h_t$๊ฐ ๋๋ค.
\[h_t = \tanh (W_x x_t + W_h h_{t-1})\]์ถ๋ ฅ์ธต์ ์ด ์๋์ธต์ ๊ฒฐ๊ณผ $h_t$๋ฅผ ์ ๋ ฅ๋ฐ์ ์ถ๋ ฅ์ธต์ ๊ฐ์ค์น $W_y$์ ์กฐํฉํด $y_t$๋ฅผ ์ถ๋ ฅํ๋ค.
\[y_t = W_y h_t\]ํ๋์ <Cell>์ ํ์ต ๋ฐ์ดํฐ์ sequence ํ๋๋ฅผ ์ฝ์ ๋๊น์ง ๋ชจ๋ ๋์ผํ weight ๊ฐ์ ์ฌ์ฉํ๋ค. ์๋ฅผ ๋ค์ด โI love youโ๋ผ๋ ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅํ๋ฉด, ๊ฐ ๋ฌธ์๋ ๋ชจ๋ ๋์ผํ weight $W_x$, $W_h$, $W_y$๋ฅผ ์ฌ์ฉํ๋ค. ๊ทธ๋์ โI love youโ ๋ฌธ์ฅ ํ๋๊ฐ ํ๋์ ๋ฐ์ดํฐ ์ ๋ ฅ์ด๋ฉฐ, Back-propagation ์ญ์ ์ด ํ ๋ฌธ์ฅ์ ๋ค ์ฝ์ ํ์ ์ผ์ด๋๋ ๊ฒ์ด๋ค.
<RNN>์์์ Back-propagation์ ๊ธฐ์กด์ <Feed Forward Network>์ ๋ฐฉ์๊ณผ๋ ์กฐ๊ธ ๋ค๋ฅด๋ค. <RNN>์์๋ time-step์ผ๋ก <Cell>์ ํผ์น ํ์ Back-prop์ ์ ์ฉํ๋ค. ์ด๋ฅผ <Backprop Through Time; BPTT>๋ผ๊ณ ํ๋ค.
๊ทธ๋ฌ๋ ์์ ๊ฐ์ ๋ฐฉ์์ ์ฒ์๋ถํฐ ๋๊น์ง ์ญ์ ํ๋ฅผ ๋ฌธ์์ ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ค๋ฉด ๊ณ์ฐ๋์ด ์ฆ๊ฐํ๋ค๋ ๋ฌธ์ ๊ฐ ์๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ ์ฒด ๋ฌธ์ฅ์ ์ผ์ ๊ตฌ๊ฐ์ ๋๋ ์ Backprop์ ์ํํ๊ธฐ๋ ํ๋ค. ์ด๊ฒ์ <Truncated BPTT>๋ผ๊ณ ํ๋ค.
๋, <RNN>๊ณผ ๊ฐ์ <Sequential Model>์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ๋์์ ๋ฐ๋ผ 1-to-1, 1-to-many, many-to-1, many-to-many ๋ฑ ๋ค์ํ ํํ๋ก ์กด์ฌํ๋ค.
LSTM; Long Short Term Memory model
<RNN>์ ๊ฒฝ์ฐ hidden state๋ฅผ ํตํด ์ด์ ์ ๋ ฅ์ ๋ํ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์์ง๋ง, ์ ๋ ฅ ์ํ์ค๊ฐ ๊ธธ์ด์ง์๋ก ์ฑ๋ฅ์ด ๋จ์ด์ง๋ค๋ ๋จ์ ์ด ์กด์ฌํ๋ค. ์ด๋ฅผ <The problem of learning long-term dependencies; ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ >๋ผ๊ณ ํ๋ค. ์ด์ ๋ํด์ ๋์ผํ ๊ฐ์ $W_h$์ ๊ฐ์ ์ฌ๋ฌ๋ฒ ์ฌ์ฉํ๊ฒ ๋๋ฉด์ ๋ฐ์ํ๋ Gradient Exploding ๋๋ Gradient Vanishing์ ์์ธ์ผ๋ก ๊ผฝ๋๋ค.
<LSTM>์ <์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ >๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด <RNN>์ ๊ตฌ์กฐ์ Cell state $c_t$๋ฅผ ์ถ๊ฐํ๊ณ ์ฌ๋ฌ ๊ฒ์ดํธ(gate)๋ฅผ ์ถ๊ฐํ ๋ชจ๋ธ์ด๋ค.
RNN(ๅทฆ)๊ณผ LSTM(ๅณ)
<LSTM>์์ cell state $c_t$๋ ์ฅ๊ธฐ ๊ธฐ์ต์ ๋ด๋นํ๋ฉฐ, hidden state $h_t$๋ ๋จ๊ธฐ ๊ธฐ์ต์ ๋ด๋นํ๋ค.
\[c_t = f_t \circ c_{t-1} + i_t \circ g_t\] \[h_t = o_t \circ \tanh(c_t)\]<LSTM>์๋ 4๊ฐ์ง ๊ฒ์ดํธ(gate)๊ฐ ์กด์ฌํ๋ฉด ๊ฐ๊ฐ์ ์๋์ ๊ฐ๋ค. ํธ์๋ฅผ ์ํด bias $b$ ํ ์ ์๋ตํ์๋ค.
1. ์ ๋ ฅ ๊ฒ์ดํธ
\[i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1})\]2. ๊ฒ์ดํธ ๊ฒ์ดํธ ๐ต
\[g_t = \tanh(W_{xg} x_t + W_{hg}h_{t-1})\]๋ ๊ฒ์ดํธ ๋ชจ๋ ํ์ฌ์ ์ ๋ ฅ $x_t$์ hidden state $h_{t-1}$๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ผ๋ฉฐ, ๋ค๋ฅธ ์ ์ ํ์ฑํ ํจ์ ๋ฟ์ด๋ค. ์ด ๋ ๊ฒ์ดํธ๋ฅผ ํตํด ํ์ฌ์ ์ ๋ ฅ์์ ๊ธฐ์ตํ ์ ๋ณด์ ์์ ์ ํ๋ค.
3. ๋ง๊ฐ ๊ฒ์ดํธ
\[f_t = \sigma(W_{xf} x_t + W_{hf}h_{t-1})\]๋ง๊ฐ ๊ฒ์ดํธ์ ๊ฐ์ ํตํด cell state $c_{t-1}$์์ ์์ ์ ๋ณด์ ์์ ์ ํ๋ค.
๋ค์ cell state $c_t$์ ์์์ ์ดํด๋ณด์. ๋ง๊ฐ ๊ฒ์ดํธ $f_t$๋ฅผ ํตํด ์ด์ cell state $c_{t-1}$์์ ๊ธฐ์ตํ ์ ๋ณด๋ฅผ ๊ฒฐ์ ํ๊ณ , ์ ๋ ฅ ๊ฒ์ดํธ $i_t$์ ๊ฒ์ดํธ ๊ฒ์ดํธ $g_t$๋ฅผ ํตํด ํ์ฌ ์ ๋ ฅ์์์ ๊ธฐ์ตํ ์ ๋ณด๋ฅผ ๊ฒฐ์ ํ๋ค. ๐คฉ
\[c_t = f_t \circ c_{t-1} + i_t \circ g_t\]4. ์ถ๋ ฅ ๊ฒ์ดํธ
\[o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1})\]์ถ๋ ฅ ๊ฒ์ดํธ์ ๊ฐ์ <Cell>์ ์ถ๋ ฅ์ด ๋๋ฉฐ, ์ดํ ์ถ๋ ฅ์ธต์์ $y$์ ๊ฐ์ ๊ตฌํ๋๋ฐ ์ฌ์ฉ๋๋ค. ๋, ์ถ๋ ฅ ๊ฒ์ดํธ์ ๊ฐ์ ํตํด cell state $c_t$์์ ๋จ๊ธฐ์ ์ผ๋ก ๊ธฐ์ตํ ์ ๋ณด๋ฅผ ๊ฒฐ์ ํ๋ค.
\[h_t = o_t \circ \tanh(c_t)\]<LSTM>์ ๊ฒฝ์ฐ <RNN>์ ๋นํด ๊ตฌ์กฐ๊ฐ ์ ๋ง ๋ณต์กํ์ง๋ง, ์ด์ ๊ณผ ๋ฌ๋ฆฌ Gradient Exploding์ด๋ Gradient Vanishing ํ์์ด ๋๋๋ฌ์ง์ง ์๋๋ค! ๐ ์์ธํ ์ด์ ๋ฅผ ์๊ณ ์ถ๋ค๋ฉด, ์ด ์ํฐํด์ ์ฐธ๊ณ ํ๋ผ.
2014๋ ์๋ <LSTM>์ ๊ฐ์ ํ <GRU; Gated Recurrent Unit>๋ผ๋ sequential model์ด ์ ์๋์๋ค. <LSTM>์ฒ๋ผ ๊ฒ์ดํธ(gate)๊ฐ ๋ฌ๋ ค์์ง๋ง, $h_t$ ํ๋์ ์ํ๋ง์ ๊ธฐ์ตํ๊ธฐ ๋๋ฌธ์ <LSTM>๋ณด๋ค ๋ ๋น ๋ฅด๊ฒ ํ์ตํ๋ค๊ณ ํ๋ค. <GRU>์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ด ์ํฐํด์ ์ฐธ๊ณ ํ๋ผ.
<RNN>, <LSTM>์ ์์ฐ์ด(Natural Language)์ ์๊ณ์ด ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๊ฐ์ฅ ๊ธฐ๋ณธ์ด ๋๋ ๋ชจ๋ธ์ด๋ค. ๊ทธ๋ฌ๋ ์ด๋ฒ ๋ด์ฉ์ ์ต์ํด์ง๋ ๊ฒ์ ์ถ์ฒํ๋ค.
โ related post
- NLP with PyTorch Cheat Sheet