본문 바로가기
논문 리뷰

Efficiently Modeling Long Sequences with Structured State Spaces(S4) (2021) Review

by rltjq09 2024. 3. 24.
728x90
 

Efficiently Modeling Long Sequences with Structured State Spaces

A central goal of sequence modeling is designing a single principled model that can address sequence data across a range of modalities and tasks, particularly on long-range dependencies. Although conventional models including RNNs, CNNs, and Transformers h

arxiv.org

0. 핵심 요약

SSM(State Space Model)의 이론적인 강점은 유지하고 더 효율적으로 계산할 수 있는 모델인 S4 제안

이는 A 행렬을 저차원으로 조절하고, 대각화를 안정적으로 수행하며, SSM이 Cauchy Kernel에서의 계산량이 줄어들도록 함

1. BackGround

기존 Sequence modeling의 기조

Sequence modeling에서 주요한 문제는 long-range dependencies를 어떻게 처리하냐이다

대부분의 모델들이 이를 평가하는 LRA(Long Range Arena)에서 매우 낮은 성능을 기록중

이를 해결하기 위해서 다양한 시도들이 있었으나, 여전히 좋은 성능을 보이지는 못함

 

최근에 이를 해결하기 위해 SSM을 딥러닝에 도입하는 시도를 하였음

하지만 이론적인 이유로 이를 딥러닝에 바로 적용할 수는 없었다

해당 이유뿐만 아니라 다른 논문에서도 SSM을 적용해보았지만 성능이 향상되지는 않았다

하지만 단순히 적용하는 것이 아닌 특별한 state matrices A를 추가하면 해당 문제를 해결 가능

기존의 다른 논문인 Linear State Space Layer(LSSL)로 인해 SSM의 가능성이 높아졌다

 

LSSL은 RNN이나 CNN보다 memory는 많이 사용하고, 이론적으로는 탄탄하지만 불안정하다는 단점 존재

 

S4는 이전 연구들의 단점을 해결함과 동시에 computational bottleneck을 해결한 안정적인 모델

 

State Space Models

state space model이란 1-D input signal $u(t)$을 N-D latent space $x(t)$로 mapping

그런 뒤, output인 1-D signal $y(t)$로 projection을 수행하는 모델

수식으로 표현하면 다음과 같다

$x'(t) = Ax(t) + Bu(t),~~~y(t) = Cx(t) + Du(t)$

 

이때, 해당 논문에서는 D를 주로 생략하거나 0으로 처리

계산이 쉽기 때문에 이를 생략하였다고 함

 

해당 논문에서의 목표는 위의 A, B, C, D를 gradient descent로 학습 가능하도록 만드는 것

 

HiPPO(Hidden Paths to Pertant Objects)

기존 SSM의 문제점으로는 linear first-order ODE가 지수함수여서 gradient가 지수 범위로 표현

이는 기울기가 explode 하거나 vanish 하는 문제가 발생

(※ ODE(ordinary differential equation) : 상미분방정식)

 

LSSL은 이를 해결하기 위해 HiPPO를 활용

HiPPO는 앞선 식의 A의 class를 지정, input $u(t)$의 이력을 기억하기 위해 $x(t)$를 사용

A는 아래의 식으로 정의된다

$\mathrm{(HiPPO~Matrix)} ~ A_{nk} = - \begin{cases} (2n+1)^{1/2} (2k+1)^{1/2} & n > k \\ n+1 & n=k \\ 0 & n < k  \end{cases}$

 

Discrete-time SSM

$u(t)$는 연속, 입력은 이산($u_0$, $u_1$, $\cdots$)따라서, 이를 해결하기 위해 앞선 SSM 식을 step size $\Delta$로 이산화

$u_k$는 $u(t)$에서 sampling한 것으로 볼 수 있음

따라서 이는 $u_k = u(k\Delta)$와 같다

 

SSM 이산화를 위해 A를 approximation A로 변환

$x_k = \bar{A}x_{k-1} + \bar{B}u_k,~~  y_k = \bar{C}x_k$

$\bar{A} = (I - \Delta / 2 \cdot A)^{-1} (T + \Delta / 2 \cdot A), ~~ \bar{B} = (I - \Delta / 2 \cdot A)^{-1} \Delta B, ~~ \bar{C} = C$

위의 수식이 Discrete SSM을 나타내며, recurrent SSM이라고도 함

 

Training SSM

앞서 설명한 Discrete SSM은 sequntiality 때문에 학습에 적합하지 않는다

따라서, 이를 학습에 적합하도록 변형하기 위해 LTI(Linear Time-Invariant) SSM과 Continuous convolution과의 관계를 활용

$x_{-1} = 0$이라 가정하면, 앞선 식을 다음과 같이 표현할 수 있다

$x_0 = \bar{B}u_0, ~~ x_1 = \bar{AB}u_0 + \bar{B}u_1, ~~ x_2 = \bar{A}^2\bar{B}u_0 + \bar{AB}u_1 + \bar{B}u_2, \cdots$

$y_0 = \bar{CB}u_0, ~~ y_1 = \bar{CAB}u_0 + \bar{CB}u_1, ~~ y_2 = \bar{CA}^2\bar{B}u_0 + \bar{CAB}u_1 + \bar{CB}u_2, \cdots$

위의 규칙을 통해 아래와 같은 식을 도출해낼 수 있다

$y_k = \bar{CA}^k\bar{B}u_0 + \bar{CA}^{k-1}\bar{B}u_1 + \cdots + \bar{CAB}u_{k-1} + \bar{CB}u_k$

$y = \bar{K} * u$

 

이를 다시 말하면, 단일 convolution이 된다는 것이고, 이는 FFT(Fast Fourier Transform)으로 계산 가능하다

 

$\bar{K} \in R^L := K_L(\bar{A},~  \bar{B},~  \bar{C}) := (\bar{CA}^i\bar{B})_{i \in [L]} = (\bar{CB}, ~\bar{CAB}, \cdots, ~\bar{CA}^{L-1}\bar{B})$

 

우리는 위의 K만 구하면 되는 것이고, 하지만 이를 직접적으로 계산하는 것은 쉽지 않다

해당 논문에서의 주된 내용이 바로 해당 K의 연산을 쉽게 변환하는 것이며 이를 SSM convolution kernel이라 함

 

 

해당 이미지는 SSM 모델의 변형 과정을 보여준다

기존의 연속적이였던 SSM을 Discrete 하게 변형하여 Convolution과 recurrent에 적용 가능하도록 함

2. 논문 핵심 내용

Diagonalization

앞선 discrete-time SSM 계산의 주된 bottleneck은 바로 $\bar{A}$의 중복 행렬 곱

따라서, 기존 행렬을 다음과 같이 변환하여 이를 해결하고자 함

$(A,~B,~C) \sim (V^{-1}AV, ~V^{-1}B,~CV)$

A를 canonical form으로 변환함으로써 계산 효율을 향상시킴

하지만 이처럼 대각화가 항상 가능한 것은 아니다

즉, 특정한 조건을 갖춘 V 만이 대각화(혹은 conjugate)가 가능하다는 것

 

Normal Plus Low-Rank

이전에 조건이 갖춰진 V만 conjugate가 가능하다 언급

HiPPO Matrix는 불행히도 해당 조건을 만족하지 않음

(+HiPPO Matrix는 sum of a normal & low-rank matrix로 변환은 가능하다 → 여전히 비효율적)

 

따라서, HiPPO Matrix가 조건을 만족하도록 아래 3가지 계산을 통해 해결

1. $\bar{K}$를 직접 계산하지 않고 truncated generating function을 활용해 우회적으로 계산

truncated generating function은 $\sum^{L-1}_{j=0}\bar{K}_j \zeta^i$

이를 통해 $\bar{K}$의 spectrum을 계산

 

2. generating function은 행렬 분해와 연관

따라서, 기존의 행렬이 power를 가지고 있었다면 이 대신 inverse를 포함하도록 변환

이를 통해 low-rank term이 woodbury identity가 됨

 

3. 최종적으로 Cauchy kernel 계산과 동일해짐

 

이를 통해, 어떤 matrix도 Normal Plus Low-Rank로 변환이 가능해짐

최종적인 수식을 표현하면 다음과 같다

 

$A = V\Lambda V^* - PQ^T = V(\Lambda - (V^*P)(V^*Q)^*)V^*$

위에서 *가 있는 행렬은 conjugate 행렬을 의미

 

 

위 방법들을 활용해 $\bar{K}$를 구하는 과정이다

 

S4 Algorithms and Computational Complexity

위의 수식에서 우리는 NPLR을 DPLR(Diagonal Plus Low-Rank)로 변환 가능하다는 것을 알 수 있음

이를 통해 총 2가지 사실을 알 수 있다

1. $\Delta$가 주어질 때, S4는 O(N)의 복잡도를 가진다

2. $\Delta$가 주어질 때, $\bar{K}$는 4 Cauchy multiplies 가 감소하고, 따라서 $\tilde{O}(N+L)$ 복잡도를 가진다

($\tilde{O}$는 soft 복잡도를 의미)

 

3. Model Architecture

S4 layer의 파라미터 구성은 다음과 같다

1. SSM을 HiPPO Matrix로 설정한 A로 초기화

2. 앞선 내용들을 바탕으로 $(A, ~B, ~C) \Rightarrow (\Lambda - PQ^*, ~B, ~C)$로 변환

이때, $\Lambda$는 대각행렬, P,Q,B,C는 $\mathbb{C}^{N\times 1}$ 인 벡터

총 5N의 trainable 파라미터로 구성

 

S4는 $\mathbb{R}^L$을 $\mathbb{R}^L$로 mapping 하는 모델

이때 입력은 1-D Sequence map

H차원의 data를 처리하기 위해 H features를 position-wise로 linear layer를 적용

 

layer 사이에 비선형 함수를 추가

 

다른 모델들과 동일하게 sequence-to-sequence map 정의 (batch size, sequence length, hidden dimension)

 

S4 module의 핵심은 선형 변환, 하지만 비선형 변환을 추가하여 deep SSM을 비선형화

또한, S4 model은 depthwise-convolution CNN과 유사(global convolution kernel을 추가한)

 

4. Experiment

S4 Efficiency Benchmarks

S4는 long-range sequence modeling에서 매우 빠르고 효율적

 

S4는 다른 모델과 비교했을 때 속도가 빠르면서 동시에 memory도 매우 효율적

 

Learning Long Range Dependencies

1. LRA(Long Range Arena)

다른 모델과 비교했을 때 매우 뛰어난 성능을 보이는 것을 확인할 수 있음

특히, Path-X task를 유일하게 성공해낸 모델

 

2. Raw Speech Classification

SC10 Dataset(Speech Commands)를 활용해 성능을 평가하였을 때, 매우 우수한 성능을 보임

특히, 음성의 길이가 긴 음성에서 더욱 큰 성능을 보였음

 

 

S4 as a General Sequence Model

S4가 domain에 상관없이 성능이 좋다는 것을 입증

왼쪽부터 1-D image Classification, CIFAR-10 density estimation, WiKi Text 103 language modeling에 대한 결과

 

5. Contribution

S4는 state space model의 연속시간, 순환 및 컨볼루션 관점에 대한 새로운 매개변수화를 사용하여 장거리 의존성(LRDs)을 원칙적으로 효율적으로 모델링하는 시퀀스 모델 제안

 

+ Comment

해당 모델은 현재 Transformer를 대체할 가능성이 있다는 Mamba의 기초 논문이 되는 논문이다

state space model과 각종 선형대수 내용이 많아 이해하는 데 많은 어려움이 있었지만 읽으면서 재미도 있었던 논문이였다

728x90