본문 바로가기
논문 리뷰

BitNet: Scaling 1-bit Transformers forLarge Language Models (2023) Review

by rltjq09 2024. 3. 13.
728x90
 

BitNet: Scaling 1-bit Transformers for Large Language Models

The increasing size of large language models has posed challenges for deployment and raised concerns about environmental impact due to high energy consumption. In this work, we introduce BitNet, a scalable and stable 1-bit Transformer architecture designed

arxiv.org

0. 핵심 요약

LLM을 위한 1-bit Transformer 제안

1-bit weight 학습을 위해 nn.Linear를 대체하는 BitLinear 제안

1. Background

기존 LLM 모델 경량화 기조

현재 LLM 모델은 꾸준히 성장하고 있지만, 이와 비례해서 LLM 모델을 사용하기 위한 비용과 energy computation 역시 같이 증가하고 있다

이는 LLM 을 배포할 때 큰 문제로 작용한다

따라서, model quantization에 대한 필요성이 증가하였다

 

기존에는 이를 위해 post-training 으로 해결하고자 하였다

post-training의 한 종류로써 fine-tuning이 존재하며, 이는 현재 많은 곳에서 실제로 활용되고 있는 기법 중 하나이다

하지만, 이러한 방식은 정확도가 떨어진다는 단점이 존재한다

 

또 다른 방법으로는 quantization-aware training이 존재한다

해당 방법은 post-training보다 정확도는 높지만 학습 과정에서 수렴이 어렵다는 단점이 존재한다

 

따라서, 해당 논문에서는 model compression을 위해 binarization에 집중하였다

이전 연구들에서는 이를 단순히 model translation이나 BERT에만 활용했지만, 해당 연구에서는 직접적으로 Transformer에 적용하였다

 

BitNet에서 변하지 않은 부분

BitNet은 기존 Transformer에서 Linear 연산만 변화하고 나머지는 변하지 않았다

이렇게 한 이유에 대해 다음과 같은 3가지를 들고 있다

1. residual connection과 layer normalization은 다른 연산에 비교하면 무시해도 될 만한 cost이다

2. QKV transformation은 모델 용량이 커질수록 parametric projection보다 더 효율적이다

3. input / output embedding은 그대로 유지, 언어 모델은 sampling 시에 높은 정확도를 요구하기 때문이다

 

Dataset

학습에 활용된 데이터셋은 총 4개이며, 모두 영어로만 학습되었다

1. Pile Dataset

2. Common Crawl snapshots

3. RealNews

4. CC-Stories Dataset

 

모든 데이터는 sentencpiece tokenizer로 토큰화를 진행하였다

 

추가로 참고할만한 논문

아래 핵심 내용을 읽다 보면 해당 논문에서는 설명하지 않고 넘어가는 것이 다수 존재한다

이때, 해당 논문들을 참고하면 많은 도움이 될 것이다

 

1. Absmax quantization

 

GPT3.int8(): 8-bit Matrix Multiplication for Transformers at Scale

Requests for name changes in the electronic proceedings will be accepted with no questions asked. However name changes may cause bibliographic tracking issues. Authors are asked to consider this carefully and discuss it with their co-authors prior to reque

proceedings.neurips.cc

2. SubLN

 

Foundation Transformers

A big convergence of model architectures across language, vision, speech, and multimodal is emerging. However, under the same name "Transformers", the above areas use different implementations for better performance, e.g., Post-LayerNorm for BERT, and Pre-

arxiv.org

3. model parallelism

 

Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

Recent work in language modeling demonstrates that training large transformer models advances the state of the art in Natural Language Processing applications. However, very large models can be quite difficult to train due to memory constraints. In this wo

arxiv.org

4. Straight-through estimator

 

Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation

Stochastic neurons and hard non-linearities can be useful for a number of reasons in deep learning models, but in many cases they pose a challenging problem: how to estimate the gradient of a loss function with respect to the input of such stochastic or no

arxiv.org

5. Mixed precision training

 

How Do Adam and Training Strategies Help BNNs Optimization

The best performing Binary Neural Networks (BNNs) are usually attained using Adam optimization and its multi-step training variants. However, to the best of our knowledge, few studies explore the f...

proceedings.mlr.press

6. Arithmetric operation energy

 

CVPR 2022 Open Access Repository

PokeBNN: A Binary Pursuit of Lightweight Accuracy Yichi Zhang, Zhiru Zhang, Lukasz Lew; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022, pp. 12475-12485 Abstract Optimization of Top-1 ImageNet promotes enormou

openaccess.thecvf.com

7. Scaling Law

 

Scaling Laws for Neural Language Models

We study empirical scaling laws for language model performance on the cross-entropy loss. The loss scales as a power-law with model size, dataset size, and the amount of compute used for training, with some trends spanning more than seven orders of magnitu

arxiv.org

 

2. 논문 핵심 내용

BitLinear

 

BitLinear를 그림으로 표현하면 위와 같다

 

우선 1-bit weights를 생성하는 과정이다

1. 원래의 가중치를 평균이 0이 되도록 centralize 수행 (이는 용량을 더욱 높히기 위한 과정이다)

2. signum function을 통해 weight를 1 또는 -1로 이진화

 

위의 과정을 수식으로 표현하면 다음과 같다

$\tilde{W} = sign(W - \alpha)$

$sign(W_{ij}) = \begin{cases} 1 & > 0 \\ -1 & \leq 0 \end{cases}$

$\alpha = {1 \over nm} \sum_{ij}W_{ij}, ~ W \in R^{n \times m}$

 

다음으로 설명할 부분은 Absmax Quantization이다

해당 부분은 activations을 quantization 하는 것이다

우선 수식을 보자면 다음과 같다

$\tilde{x} = Quant(x) = Clip(x \times {Q_b \over \gamma}, -Q_b + \epsilon, Q_b - \epsilon)$

$Clip(x, a, b) = max(a, min(b, x))$

$\gamma = ||x||_{infty}$

 

$\gamma$에 대해서 먼저 설명하자면, 해당 수식은 입력 데이터 중에서 크기가 가장 큰 데이터를 골라내는 것이다

이를 그대로 위의 식에 적용하면 입력 데이터마다 최대값을 나눠준다는 의미가 된다

다음으로 $Q_b = 2^{b-1}$는 우리가 scaling 하고자 하는 범위를 나타낸다 (여기서 b는 비트수를 의미)

 $\epsilon$은 overflow를 방지하기 위한 값으로 매우 작은 값으로 설정한다

 

위의 수식을 논문에서는 ReLU 에 적용을 하였는데, 과정은 다음과 같다

우선 ReLU는 음수가 없기 때문에 $[0, Q_b]$로 스케일링을 수행하고자 한다

따라서, 수식은 다음과 같다

$\tilde{x} = Clip((x - \eta) \times {Q_b \over \gamma}, -Q_b + \epsilon, Q_b - \epsilon)$

$\eta = \underset{ij}min ~x_{ij}$

이를 수행하여 활성화 함수를 8-bit로 변환한다

 

이러한 양자화 과정은 학습 시에는 매 tensor 마다, inference 시에는 매 token 마다 수행된다

따라서, 이는 $y = \tilde{W} \tilde{x}$ 로 표현할 수 있다

이때, W와 x는 서로 독립이면서 동시에 동일한 분포를 따른다고 가정하면 y의 분산은 다음과 같이 추정할 수 있다

 

만약, full-precision computation일 때는 y의 분산이 일반적인 파라미터 초기화 기법처럼 1에 근사한다

따라서 이를 양자화한 이후에도 최대한 유지하는 것이 중요한데, 이를 위해 activation quantization 이전에 LayerNorm을 적용한다

$Var(y) \approx E[LN(\tilde{x}^2] = 1$

위의 수식을 보면 이 역시 1에 근사하는 것을 볼 수 있다

이렇게 분산을 1로 근사하도록 LayerNorm을 적용하는 것이 SubLN이다

 

최종적으로 SubLN과 앞서 설명한 Quanization을 결합하면 이것이 BitLinear가 되는 것이며 수식은 다음과 같다

$y = \tilde{W}\tilde{x} = \tilde{W} Quant(LN(x)) \times {\beta \gamma \over Q_b}$

$LN(x) = {x - E(x) \over \sqrt{Var(x) + \epsilon}}, ~ \beta = {1 \over nm} ||W||_1$

 

따라서 전체 과정을 다시 설명하면

1. 입력된 데이터에 대해서 LayerNorm을 수행한다

2. absmax quantization을 수행함으로써 $\gamma$를 구하고 비선형 함수를 적용한다

3. signum function을 통해 가중치를 이진화하고, $\beta$를 구한다

4. 이진화된 가중치와 비선형 함수를 적용한 행렬을 행렬 곱한 뒤, 해당 결과에 $\gamma$와 $\beta$를 적용하여 최종 출력

 

Model Parallelism

LLM 모델의 크기를 키우기 위해서는 model parallelism이 매우 중요하다

이때 model parallelism을 하는 방법을 간단하게 설명하면 모델이 수행하는 여러 행렬 곱들을 여러 divice로 분할하여 연산하는 것이다

이를 위해서는 tensor들이 모두 독립이어야 하는데, 위의 연산을 보면 여러 파라미터들이 여러 tensor에 걸쳐 계산되어 독립성을 위반한다

또한, SubLN 역시 독립성을 위반하는 연산이다

따라서, 이를 우회적으로 계산하기 위해 weight와 activation을 모두 group화를 진행, 각 group 마다 연산을 진행한다

이를 group quantization이라고 하며, group을 나누는 기준은 데이터의 행을 기준으로 분할한다

 

LN 역시 group normalization을 도입하여 model parallelism을 사용할 수 있도록 하였다

3. Model Architecture

해당 이미지는 위에서 설명한 BitLinear가 Transformer에서 어떻게 쓰이는 지를 도식화한 그림이다

BitLinear를 제외하면 기존의 Transformer와 모두 동일하기 때문에 모델 구조에 있어서는 어려움이 없다

 

Model Training

모델을 학습하는 데에 있어서 총 3가지를 강조하고 있다

1. Straight-through estimator

이는 간단히 설명해 sign이나 clip과 같이 미분이 불가능한 연산들을 gradient 연산을 할 때 자연스럽게 우회할 수 있도록 해주는 기법이다

자세한 사항은 앞서 소개한 논문을 참고하기 바란다

 

2. Mixed precision training

모델 파라미터의 비트 수를 낮추게 되면 그만큼 정보 손실이 일어나게 된다

이러한 정보 손실을 최소화하기 위해 학습 과정에서의 기울기와 최적화 상태는 기존의 정보 그대로 유지를 한다

그런 뒤 이러한 값을 학습에 그대로 활용하면서 최대한 정보를 활용하고자 하는 방식이다

이는 학습에서만 활용되며, inference 시에는 사용하지 않는다

더욱 자세한 내용은 역시 소개한 논문을 참고하기 바란다

 

3. Large learning rate

해당 모델은 파라미터가 1 혹은 -1로 존재한다

따라서 학습률이 너무 작으면 파라미터가 전혀 바뀌지 않는 문제가 발생한다

따라서 학습률을 크게 조정하여 적절한 학습이 이루어지도록 한다

 

Computational Efficiency

해당 논문에서는 arithmetric operation energy와 memory footprint로 계산 효율을 계산한다

LLM의 주요 계산은 행렬 곱 연산이기 때문에 이에 집중하여 계산하였다

 

BitNet과 Transformer의 행렬 곱 연산에서의 에너지 차이를 보면 다음과 같다

기존 Transformer

$E_{mul} = m \times n \times p \times \hat{E}_{mul}$

BitNet

$E_{mul} = (m \times p + m \times n) \times \hat{E}_{mul}$

 

이렇게 기존의 곱하기 연산을 더하기로 바꿀 수 있었던 이유는 바로 행렬에서 계산한 파라미터로 행렬을 계산하기 때문이다

energy를 계산한 표이며, 자세한 수치에 대한 해석은 생략하겠다

 

Inference-optimal scaling Law

Transformer로 구성된 언어 모델은 계산량을 통해 손실을 계산할 수 있어 최적의 자원 할당이 가능하다고 한다

BitNet 역시 해당 법칙을 따르는지 확인해보기 위해 token 수는 고정하고 모델의 크기를 변화해가며 loss 그래프를 확인

위의 그래프를 보면 기존의 Transformer와 유사한 경향을 보이기 때문에 법칙을 따른다고 봐도 무방하다

따라서, 이를 통해 BitNet의 모델 크기를 125M ~ 6.7B 로 결정하였다

 

기존의 모델들은 FLOPs를 통해 모델의 계산 비용을 추정하였다

하지만, 이는 정수 계산을 통해 추정하기 때문에 1비트로만 연산하는 BitNet에는 적절하지 않다

따라서 energy computation 대비 loss를 계산하는 Inference-optimal scaling law 라는 방법을 제안

위의 그래프를 보면 동일한 energy 대비 BitNet이 더 loss가 낮은 것을 볼 수 있다

 

4. Experiment

Downstream tasks

Hellaswag, Winogrande, Winograd, Storycloze 총 4개의 downstream task에서 평가를 진행하였으며, 0-shot과 4-shot 을 평가하였다

왼쪽이 0-shot, 오른쪽이 few-shot 의 결과이다

모두 BitNet이 energy 대비 정확도가 높은 것을 볼 수 있다

 

Stability test

경량화 모델들은 정확도만큼 모델의 안정성 역시 중요하다

이를 확인하기 위해 모델의 학습 시 수렴되는 양상을 확인하였다

BitNet이 수렴이 매우 안정적으로 되는 것을 확인할 수 있다

 

Other Experiment

 

Transformer 뿐만 아니라 다른 모델들과도 비교를 해보았을 때, 정확도가 뛰어난 것을 볼 수 있다

 

5. Contribution

BitNet은 대규모 언어 모델을 위한 새로운 1비트 트랜스포머 아키텍처

효율적인 처리 능력과 스케일 가능성 및 안정성을 목표로 설계

BitNet이 기준 모델에 비해 메모리 사용량과 에너지 소비를 크게 줄이면서도, 혼란도(perplexity) 및 다운스트림 작업 성능 면에서 경쟁력 있는 성과를 달성함

+ Comment

Transformer는 여전히 많은 논문에서 활발히 사용되고 있는 구조이다

LLM 모델의 경량화를 강조하는 요즘, 이러한 연구는 앞으로도 매우 유용할 것이라고 생각한다

728x90