David의 개발 이야기!

LSTM 에 대해 알아보자 본문

자연어처리

LSTM 에 대해 알아보자

david.kim2028 2023. 8. 19. 21:12
반응형

2023.08.19 - [자연어처리] - RNN 에 대해 알아보자

 

RNN 에 대해 알아보자

RNN(Recurrent Neural Network) 는 입력과 출력을 시퀀스 단위로 처리하는 시퀀스(Sequence) 모델이다. 1. RNN 예시 1. POS tagging RNN은 아래그림 같이, POS tagging(품사 분류)를 할 수 있다. 모델 구조를 좀 더 자세

david-kim2028.tistory.com

 


1. Vanilla RNN 의 한계

앞선 글에서 알수 있듯이, 바닐라 RNN 은 "Long-Term Dependencies" 라는 문제가 있다. 이러한 문제를 극복하기 위한 방법중 하나가 LSTM 이다. "Long-Term Depedencies"를 좀 더 자세히 설명하면, 아래 그림과 같다.

 

RNN 에서 BPTT를 사용하여, 하이퍼파리미터들을 수정하는데, 이 과정에서 다음과 같은 문제가 발생할 수 있다.

 

 

만약에, 극단적인 예시로, 곱해지는 값들이 1보다 작거나, 클때, gradient가 유의미하게 학습되지 않음을 볼 수 있다.

 

예를 들자면, 

"홍대에 놀러왔는데, 사람들도 많고, 맛있는게 많았어. 그런데, 갑자기 급한 전화가 왔어. 학교에서 도망갔는데, 어디갔냐고 묻더라고.. 그래서 저 여기 있는데요 __ " 라는 문장이 있다고 하자. 

 

다음 단어를 예측하기 위해서는 장소 정보가 필요한데, 장소 정보에 해당되는 단어인 "홍대"는 앞에 있고, RNN은 장기기억을 가지지 못하기 때문에 단어를 엉뚱하게 예측한다. 위 그림에서 볼수 있다시피, x1 에 수많은 값들이 곱해지며, 학습되므로, x1은 전체정보에 대한 영향력이 사라지는 것이다. 

 

2. LSTM(Long Short Term Memory)

 

위 그림은, LSTM 모델을 보여준다. LSTM 에는 vanilla RNN의 은닉층에, "forget gate layer", "input gate layer", "output gate layer"를 추가하여, 불필요한 기억을 지우고, 기억할 정보를 정해 vanilla RNN의 한계를 극복한다. Cell State(셀상태) 라는 값을 추가하여, 장기기억을 보존한다. 

 

LSTM 대략적인 구조는 아래 그림과 같다.

 

 

(1) forget gate layer 

 

 

forget gate layer 는 기억을 삭제하기 위한 게이트로, 현재시점 t의 x값(input r값)과 이전 시점 t-1의 은닉상태(ht-1)가 시그모이드 함수를 지나게 된다. 그림에서 오른쪽하단 그래프를 보면 알 수 있듯이, sigmoid 함수를 지나면, 0과 1사이의 값이 나오게 되는데, 이 값이 결국 삭제 과정을 거친 정보의 양을 의미한다. 1에 가까울수록 정보를 많이 기억하는 것이다. (과거의 정보를 얼마나 보존시킬지 결정하는거라 이해하면 된다.)

 

sigmoid 값을 지난 값이 0.2 가 나왔다면, Ct-1 곱해져 다음 스텝으로 넘어가게 된다.

 

 

 

(2) input gate layer  & Cell State

 

(it 가 sigmoid 이후 계산값, gt 가 tanh 이후 계산값)

 

input layer gate 는 현재 정보를 기억하기 위한 게이트다. 현재 시점 t의 x값(input 값 == jane) 과 input gate 로 이어지는 가중치 Wxi 를 곱한 값과, 이전 시점 t-1의 은닉상태(ht-1) 가 input gate로 이어지는 Whi를 곱해 Sigmoid 함수를 통과한다. 이를 it리 한다. 

 

그리고 현재 시점 t의 x값(input 값 == jane) 과 입력게이트로 이어지는 가중치 Wxg를 곱한 값과 이전 시점 t-1 의 은닉 상태가 input gate로 이어지는 가중치 Whg를 곱한 값을 더해 tanh 를 통과한다. 이를 gt 라고 한다. 

 

Sigmoid 를 지나, 0과 1사이의 값을 가지는 it와 tanh 함수를 지나 -1과 1사이의 값을 가지는 gt 이 두가지를 곱해 기억할 정보의 양을 결정한다. 

 

Cell State Ct를 구하는 방법은 다음과 같다.

input gate에서 구한 it, gt 이 두개의 값에 대해 원소별 곱(entrywise product)을 진행한다. ( entrywise product == 같은 크기의 두 행렬이 있을때 같은 위치의 성분끼리 곱하는 것을 의미한다. )

 

input gate에서 선택된 기억을 forget gate의 결과값을 더하고, 이 값을 현재시점 t의 Cell State라고 하며, 이 값은 t + 1시점의 LSTM셀로 넘겨진다. 

 

(3) Output Gate

Cell State

 

 

Memory Cell 에 있는 정보가 tanh를 통해 들어오고, hidden State의 현재정보가 Sigmoid를 통해 들어와, 곱해져, 출력이되고, 다음 셀로 넘겨진다. 

 

 

코드 실습은 깃헙에 있습니다!

https://github.com/Kdavid2355/ai_code/blob/main/LSTM_Predict_SIne_Function.ipynb

 

 

 

 

출처

https://www.youtube.com/watch?v=bX6GLbpw-A4&t=3s

https://colah.github.io/posts/2015-08-Understanding-LSTMs/

반응형
Comments