https://arxiv.org/pdf/2002.10941
Attention mechanism 은 transformer 모델 기반의 핵심 연산으로, 자연어 처리와 컴퓨터 비전 영역에서 다용되고 있다. 이 attention mechanism은 inference 단계에서 높은 연산 비율을 차지하고 있음을 background에서 밝히며, 이 문제를 해결하고자 hw, sw co-design 설계 방법을 논문에서 보여준다.
먼저. attention mechanism을 수도 코드로 표현하였는데 이 수도 코드를 그림으로 크게 그린다면, 아래 그림과 같다.
attention mechanism의 연산량 비용을 논문에서는 소개하며 이 연산량은 n(시퀀스 길이), d(임베딩 차원)으로 의존하고 있음을 보여준다. 중요한 점은, 해당 논문에서는 d보다 n이 클수록 모델이 강력함을 시사한다.
n은 외부 지식 베이스에 있는 데이터 개수로, 얼마나 많은 데이터를 참고하는지를 보여주는 지표이고, d는 하나의 벡터 항목의 차원 수로 각각의 데이터를(토큰)을 얼마나 정교하게 표현하는가를 나타낸다. 이 값 중, n이 커질수록 모델이 더 많은 데이터 중에서 유용한 정보 탐색이 가능해짐을 주장하며, 논문에서 n의 값을 키우는데 집중한다. (hw design에서도 n=320, d=64로 고정하여 실험하는 것을 볼수 있음)
이 막대 그래프는 attention mechanism이 실제 모델의 연산 과정에서 얼마나 많은 비중을 차지하는가를 보여주는데, 전체 inference time과 question-answering time 두 가지 지표로 구분하여 측정을 하였다. (이때 사용한 모델은 memn2n, kvmemn2n, bert 총 3가지 모델로 진행, hw는 cpu, gpu-bert only) 붉은 박스를 보면 전체적인 attention mechanism이 차지하는 비율은 높아졌는데 반면 파란 박스는 비율이 유지되고 있음을 막대 그래프를 통해 볼 수 있다. 논문에서는 이 점을 시사하길. 두 모델의 inference 중 comprehension과 question response time의 차이 때문이라고 주장한다.
comprehension은 모델이 제공된 데이터를 이해하고 처리하는 시간으로, wikimovies task에서 kv-memn2n은 영화 관련 정보를 미리 읽고 이해하여 벡터 표현 과정 같은 것이 예시가 된다. question response time은 사용자의 질의에 모델이 응답을 생성하는데 걸리는 시간이다. 즉 memn2n(kvmem n2n)은 inference 과정에서 comprehension -> question response time이 순차적으로 진행되지만, BERT는 질문이 들어올 때, 처리해야할 사전 정보 또한 동시에 진행되므로 Inference가 시작될 때 comprehension과 question response time이 동시에 시작된다. 따라서 BERT는 inference time의 시간과 question-answering time이 동일하여 비율의 변화가 없음을 밝힌다.
결론적으로 그래프에서 각 모델의 workload에서 attention이 차지하는 비율은 최소 약 30프로 최대 약 80프로고 최적화의 필요성을 보여준다. 또한 기존 attention mechanism의 기존 구현 법들은 tensorflow나 pytorch framework로 최적화를 하였고 이들은 attention mechanism의 본질인 "검색"을 이용하지 않는다고 주장한다. attention mechanism의 검색이란, 질의에 맞는 key를 찾는 과정들이며 이 과정에서 softmax이후 0에 가까워져 최종 출력에 영향이 없는 점, 즉 계산 낭비 문제점을 지적한다.
따라서, 논문은 불필요한 연산을 피하는 "근사"방법을 제시한다. 방법은 key matrix를 사전에 전처리 (크기 순으로 나열)하여 높은 QK값을 가질 수 있는 후보(Candidate)를 선별하는 방법들이다. 이 방법의 기대는 내적, softmax, 최종 연산량을 줄일 것임을 기대할 수 있다.
우선 approximation을 반영한 attention은 크게 두 단계로 구분되며 각 단계는 pre-process, post-scoring으로 분리된다. 우선 approximation이란 다음 그림과 같이 필자는 이해했다.
Approximation
기존 attention mechanism의 vector-matrix multiplication 단계 중 softmax단계를 지나 최종적으로 결과 값에 영향을 주는 가중치들은 한정적일 것이다. (아래 그림에서 weight의 붉은 값들이 유의미한 값, 회색 값들은 무의미한 값) 즉 실제로 기여도가 높은 값들을 사전에 선별하는 것이 목표이다. 이를 논문에서는 candidate selection이라고 지칭한다.
*필자는 candidate를 greedy score 의 최종 합산 직전, 이 값의 후보로 들어갈 수 있는 element라고 이해를 했다. 틀릴 수 있다!
(approximation attention mechanism의 pre-process를 소개할 때 논문에서는 base greedy candidate search, efficient greedy candidate search 두 단계로 소개하며 이전 단계의 문제점을 개선하여 efficient가 어떻게 개선되었는지를 이해하는 것이 포인트였던 것 같다.)
우선 approximation attention의 Base Greedy candidate search에 대해 이해해보자.
Pre-process : Base Greedy Candidate Search
Base Greedy Candidate Search는 아래 그림과 같다.
과정을 크게 본다면, 먼저 Query vector를 key matrix의 n행 수 만큼 복제해서 matrix로 만든다. 두 번째로 key matrix와 query를 element wise multiplication (항들끼리 곱함)을 하여 result를 생선한다. 다음으로 Greedy Approximation을 진행하는데 이 단계의 의미는 다음과 같다. result table에서 가장 큰 값과 가장 작은 값을 먼저 찾아 낸다.(iteration 1st) 이렇게 찾아낸 최대, 최소 값들은 각 element가 속한 행의 번호에 맞게 greedy score에 저장한다. 또 다시 이 과정을 반복하여 (iteration 2nd) 그 다음으로 가장 큰, 작은 값들을 찾아주고 또한 이 과정을 반복한다. (iteration 3rd) 이렇게 찾아낸 최대값 3개, 최소값 3개들은 greedy score로 각 행에 맞게 값이 합산되어 있다. 이 값들은 실제 result table을 행 단위로 합산했을 때의 값과 유사함을 위 그림에서 보여준다. 이 단계의 논리는 다음과 같다. result table에서 우리는 행 단위로 있는 원소들을 합산할탠데 이 단계에서 가장 큰 값과 가장 작은 값들을 합할 때 나오는 값은 true score랑 유사할 것임이 아이디어이다. 이 점을 반영한 것을 greedy score라고 생각한다. 이때 음수로 나온 greedy score은 다음 단계의 softmax연산을 위해 (불필요한 연산, 즉 exp(x)의 x값이 음수여서 0에 가까워지는 값을 만들어내는 요소) 배제한다. 최종 greedy approximation의 반복횟수 M은 hyperparameter로 사용자가 조절 가능하다.
하지만 아직 이 과정은 비효율적임을 논문에서 주장한다. 기존의 QK과정 (위 그림에서 STEP2까지)의 연산량은 O(nd) (element-wise multiplication으로 인해)이지만, base greedy candidate search 는 유의미한 값을 찾기 위한 과정은 iteration M이 추가되기 때문에 연산량이 O(nd * lognd)로 lognd배만큼 크다. 따라서 이 비효율을 개선해야 한다.
Pre-process : Efficient Greedy Candidate Search
위 base greedy candidate search를 개선한 알고리즘 아이디어가 이와 같다. 논문에서는 figure 7, 8로 수도코드와 도표로 표현하여 설명하였는데 필자의 가내 수공업으로 이 내용을 정리해보았다..
단계는 다음과 같이 구성된다.
1. query가 들어오기 전
[step1]
key matrix의 값들을 우선 열 단위로 큰 값부터 작은 값들로 정렬한다. 이 정렬한 matrix는 processed key로 지칭한다.
2. query가 들어오고 나서
[step2]
query가 들어오고 나서 query의 각 원소 값들의 음과 양을 확인한다. query vector의 열에서 음의 값을 가진 항이 확인되면 key matrix에서 가장 큰 음의 값을 추출해주고, 반대로 양의 값을 가진 항이 확인되면 key matrix에서 가장 큰 양의 값을 추출해준다. 이는 query * key를 할 때 같은 부호끼리 곱해줄 때 큰 값을 도출한다는 아이디어이다. (*주의 : 해당 과정은 greedy score의 max값만 찾을 때 서술하는 내용이므로 min의 과정은 생략되어 있음 또한 min은 방금 이 순서에서 부호가 반대가 되는 key를 찾아줘야 함)
[step3]
processed key에서 가장 큰/작은 값들을 찾아주고 각 값들의 실제 key matrix에서 행 번호 또한 max_ptr로 저장해준다.
이렇게 찾은 가장 큰 값/작은 값 key를 query와 곱을 하여 value를 찾아준다. 이렇게 찾은 값들 중 가장 큰 값만 남겨 greedy score에 더해준다.
step1 ~ step 3 M번 반복
이전 과정들을 M번 반복해준다.
이 반복을 통해 efficient greedy 는 복잡도가 O(M lognd)로 개선된다. 이 과정을 통해 동일한 의문을 가졌을 것 같은데 hyperparameter M으로 인해 trade-off가 있다. M의 값이 커질수록 정확도는 유지되고 반대로 작아진다면 accuracy drop이 있을 것임이 예상된다. 하지만 evaluation 단계에서는 그 정확도 감소는 적음을 보여준다.
Post-scoring
5. 피드백
오타 지적은 매우 환영
잘못된 정보 피드백은 더욱 환영..