Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution
arxiv.org
0. 핵심 요약
SSM의 parameter를 입력의 함수로 설정 → 정보를 전달할지, 잊어버릴지를 현재 token에 기반해 결정
recurrent mode의 hardware-aware parallel algorithm 설계
attention과 MLP 없이 Neural Networks와 selective SSMs 결합(이를 Mamba라고 함)
1. BackGround
최근 모델의 동향
최근 머신러닝의 paradigm은 매우 큰 데이터를 통해 거대 모델을 학습시키고, 이를 통해 여러 세부 task를 커버하는 방식
거대 모델로는 주로 Transformer 혹은 Attention으로 구성되어 있으며, 임의의 길이를 입력으로 받는다
Transformer(혹은 Attention)을 사용하는 이유는 해당 구조가 context window로 구성된 정보를 연결하는 데에 있어 뛰어난 성능을 보이기 때문
하지만, 해당 모델에도 한계점이 존재
window 크기를 제한이 있고, window length에 대해 quadratic scaling을 갖는다
이러한 문제를 해결하기 위해 다양한 방안이 제시되었지만, 이는 모두 계산 상의 효율을 위한 방안일 뿐 근본적인 해결책은 없었다
이러한 모델의 한계를 극복하기 위해 최근 SSMs이 주로 사용
SSMs의 장점은 계산이 매우 효율적이라는 것과 long-range dependencies에 매우 적합하다는 것
이러한 특징 때문에 audio난 vision과 같은 연속적인 signal data에서는 효과적
하지만, text와 같은 discrete data에서는 효과적이지 않음
해당 논문에서는 이러한 기존의 SSMs의 단점을 극복하고자 Selective SSMs을 제안
State Space Models
Mamba 모델은 SSMs을 토대로 생성
그중에서도 S4 모델을 많이 참고하였는데, 해당 모델에 대한 설명은 이전 논문 리뷰 참고
Efficiently Modeling Long Sequences with Structured State Spaces(S4) (2021) Review
Efficiently Modeling Long Sequences with Structured State Spaces A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Alth
rltjq09.tistory.com
2. 논문 핵심 내용
Motivation: Selection as a Means of Compression
Sequence Modeling의 근본적인 문제는 context를 small state로 압축하는 것
해당 문제는 tradeoff가 존재 → 모델의 성능이 향상되면 state를 많이 압축하지 못해 용량이 증가
해당 tradeoff는 synthetic tasks의 학습 예시를 통해 확인 가능
1. selective copying : token의 위치 기억 필수 → context-aware reasoning 필요
2. Induction Heads : context 내 학습 능력 평가 → 역시 context-aware reasoning 필요
위의 두 예시로 LTI(Linear Time Invariance) models의 실패를 알 수 있음
※ LTI : 시간에 따라 선형적으로 모델의 속도가 변하는 모델을 의미
In recurrent,
context에서 정보 선택을 하지 않기에, 입력을 그대로 전달하게 됨
In convolutional,
time-awareness가 필요한 vanilla copying task는 가능, 하지만 content-awareness인 selective copying task는 불가
결론적으로, sequence model에서 tradeoff를 해결하는 것은 state를 얼마나 잘 compress 하냐에 달렸다
Improving SSMs with Selection
parameter를 입력과 sequence와 상호작용하도록 설정하는 것을 selective mechanism 이라고 함
이를 기존 S4 모델에 적용하여 Mamba에서는 S6를 제안
위의 알고리즘 표를 보면 주황색으로 어디가 변화했는 지를 볼 수 있음
1. $\Delta,~ B,~ C$를 input에 대한 함수로 설정
2. L → time에 따른 길이
3. $S_B(x) = Linear_N(x), ~ S_C(x) = Linear_N(x), ~ S_{\Delta}(x) = Broadcast_D(Linear_1(x)), ~ \tau_{\Delta} = softplus$
4. $Linear_d$ → d차원으로 projection 한다는 의미
Efficient Implementation of Selective SSMs
이전 SSM 모델들의 계산 상의 단점
1. model의 state가 커질수록 속도 감소 → speed와 memory 소비 감소 & state는 최대로
2. x나 y보다 크게 latent state h가 필요 → state 계산은 bypass, $\bar{K}$만 유지
3. recurrent-convolutional forms로 state 최대화
이를 해결하기 위해 3가지 기술 도입 : kernel fusion, parallel scan, recomputation
이때 또 우리가 알아야 하는 2가지 사실이 있다
1. recurrent 계산은 $O(BLDN)$, convolutional 계산은 $O(BLDlog(L))$ → long sequences나 크지 않은 state N에서는 FLOPs가 적게 사용됨
2. 2가지 challenges는 재귀의 sequential과 large memory 사용 → state h 전부를 물질화하지 않아도 됨
따라서, 우리의 main idea는 state h를 memory에서 효율적인 정도만 물질화하는 것
행렬 곱을 제외한 대부분의 연산은 memory bandwidth 안에서 처리 가능
해당 연산에는 scan 연산 역시 포함
따라서, $(\bar{A},~ \bar{B})$를 GPU의 HBM(high-bandwidth memory)에서 scan 하지 않음
SSM 파라미터($\Delta, A, B, C$)를 HBM에서 SRAM으로 load한 뒤 SRAM에서 계산한 후 다시 HBM으로 전달
또한, backpropagation을 위해 중간의 state를 저장하지 않음
이는 input이 HBM에서 SRAM으로 이동하여 계산하는 동안 backward는 중간 state를 동시에 계산
3. Model Architecture
Selective SSMs는 neural network에 잘 적용하기 위해 standalone sequence transformations로 설계
기존 H3와 달리 stacked 구조로 변경
D 차원을 E로 확장, 각 block 별로 $3ED^2$을 $2ED^2$ 또는 $ED^2$로 projection
E=2로 설정(Transformer와 유사)
SiLU / Swish activation function 사용
LayerNorm 적용
4. Experiment
5. Contribution
구조화된 상태 공간 모델에 선택 메커니즘을 도입하여, 시퀀스 길이에 따라 선형으로 확장되면서 문맥 의존적 추론을 수행Mamba는 다양한 도메인에서 최고의 성과를 달성했으며, 강력한 Transformer 모델의 성능을 맞추거나 뛰어넘음
+ Comment
최근 급격하게 떠오르는 모델 구조이다
다양한 task에 적용하고 있는 만큼, 정말로 transformer를 뛰어넘을 수 있는 구조가 될 수 있을 지 궁금하다