DeepLearing/NLP(FineTuning)

[논문리뷰](24.04) Better & Faster Large Language Models via Multi-token Prediction

notdecidedyet 2024. 12. 11. 17:40

 

0 Abstract

  • 기존: LLM이 한 번에 하나의 다음 토큰만 예측 (next-token prediction)
  • 제안: 한 번에 여러 개의 미래 토큰을 동시에 예측 (multi-token prediction)
    • 언어 모델을 훈련할 때 한 번에 여러 개의 토큰을 예측하도록 하는 것이 효율성이 높다.(코스트, 수렴 속도 등)
    • Shared model trunk와 n개의 독립적인 output heads 사용
  • 성능 : 모델 크기가 클수록 더 효과적이며, 코딩과 같은 생성적 벤치마크에서 더 효과적임.
    • next-token 모델보다 HumanEval에서 12% 더 많은 문제를, MBPP에서 17% 더 많은 문제를 해결

 

1. Introduction

 

  • 현 시점의 LLM은 next token prediction으로 LLM을 훈련하여, 엄청난 양의 지식과 기본적인 reasoning capabilities를 학습
  • 허나, next token prediction으로 지식 및 Reasoning capabilities를 얻기까지는 매우 비효율적인 과정이 들어가있음.
    • 예) teacher forcing과 next-token prediction은 지역적 패턴에 집착하고 hard decisions를 간과하여, 인간과 동일한 수준에 도달하려면 수십배 더 많은 데이터가 필요함.
  • LLMs를 multiple tokens를 한번에 예측하도록 훈련하면, 더 나은 sample efficiency로 이끌 것이라고 생각함. Figure 1 처럼, multi-token prediction은 LLM이 n개의 미래 토큰을 모두 한번에 그리고 병렬로 예측하도록 한다.

 

 

 

 

 

 

 

기여점: Multi-token prediction이 이전 문헌(Qi et al., 2020)에서 연구되었지만, 이 연구는 다음과 같은 기여를 한다.

  1. 훈련 시간이나 메모리 오버헤드가 없는 간단한 multi-token prediction 아키텍처를 제안(Section 2).
  2. 이 훈련 패러다임이 대규모에서 유익하다는 실험적 증거를 제공하며, 13B 파라미터까지의 모델이 평균적으로 약 15% 더 많은 코드 문제를 해결(Section 3).
  3. Multi-token prediction은 self-speculative decoding을 가능하게 하여, 다양한 batch-sizes에서 inference 시간을 최대 3배까지 더 빠르게 만든다(Section 3.2).

 

2. Method

  • 기존 Next token prediction Loss function : minimize $L_1$ loss (CrossEntropy)
    • $L_1 = -\sum_t \log P_\theta(x_{t+1} | x_{t:1})$
      • $P_\theta$ : LLM
      • 목표 : 과거 토큰들 $x_{t:1} = x_t, ..., x_1$이 주어졌을 때 다음 토큰 $x_{t+1}$의 확률을 최대화하도록 학습
  • 제안 : 위 방식을 일반화하여 Multi-token prediction task 구현, 모델은 한번에 n개의 미래 토큰을 예측하도록 지시함.
    • $L_n = -\sum_t \log P_\theta(x_{t+n:t+1} | x_{t:1})$
    • 이를 실용적으로 하기 위해 다음을 한다.
      • 관찰된 컨텍스트 $x_{t:1}$의 잠재 표현 $z_{t:1}$을 생성하는 공유 trunk
      • n개의 미래 토큰을 병렬로 예측하기 위한 n개의 독립적인 heads
    • 위 정보를 바탕으로 multi-token prediction cross-entropy loss를 다음과 같이 유도함.
      • $L_n = \logP_\theta(x_{t+n:t+1}|z_{t:1}) \cdot P_\theta(z_{t:1} | x_{t:1})$
      • $L_n = -\sum_t \sum_{i=1}^n \log P_\theta(x_{t+i}| z_{t:1}) \cdot P_\theta(z_{t:1} | x_{t:1})$
    • 실제로 아키텍처는 다음과 같이 구성:
      • $f_s$ : 관찰된 컨텍스트 $x_{t:1}$에서 hidden representation $z_{t:1}$을 생성하는 shared transformer trunk
      • $f_{h_i}$ : transformer 레이어로 구현된 n개의 독립적인 output heads
      • $f_u$ : 공유 unembedding matrix
    • n개의 미래 토큰을 예측하기 위해 다음을 계산:
      • $P_\theta(x_{t+i}| x_{t:1}) = \text{softmax}(f_u(f_{h_i}(f_s(x_{t:1}))))$
        • $i = 1, ..., n$
        • 예) $P_\theta(x_{t+1} | x_{t:1})$는 next-token prediction head

Memory-efficient implementation

  • Multi-token predictor 훈련의 큰 도전 과제는 GPU 메모리 사용량 감소
  • Vocabulary 크기 V가 latent representation의 차원 d보다 훨씬 크기 때문에 logit 벡터가 GPU 메모리 사용의 병목점이 됨
    • Llama3의 경우 V의 경우 약 128,000(?)으로 기억함. 반면 Latent representation 약 16,000
    • 16000 -> 128,000 맵핑하는 곳에서 병목이 있음
  • $(n, V)$ 모양의 모든 logit과 그 그래디언트를 구체화하는 단순한 구현은 허용 가능한 batch 크기와 평균 GPU 메모리 사용률을 심각하게 제한
  • 이러한 이유로, forward와 backward 연산의 순서를 신중하게 조정하는 아키텍처를 제안

  • 특히 shared trunk $f_s$를 통한 forward pass 후, 각 독립적인 output head $f_i$의 forward와 backward pass를 순차적으로 계산하며 trunk에서 그래디언트를 누적
  • output head $f_i$에 대한 logit(과 그 그래디언트)를 생성하는 동안, 다음 output head $f_{i+1}$로 진행하기 전에 이를 해제
  • 결과적으로 피크 GPU 메모리 사용량을 $O(nV + d)$에서 $O(V + d)$로 감소시켰으며, 런타임에는 영향이 없음
  • 추론시 : 가장 기본적인 사용은 next-token prediction head $P_\theta(x_{t+1} | x_{t:1})$를 사용한 일반적인 자기회귀 예측이며, 다른 heads는 무시한다. 허나, 추가 output heads는 blockwise parallel decoding이나 Medusa-like tree attention과 같은 self-speculative decoding 방법을 통해 디코딩 속도를 향상시키는 데 활용될 수 있음

 

3. Experiments on real data

  • next token predictord와 N-token predictor의 공평한 비교를위해 항상 동일한 양의 파라미터로 비교를 진행. 
    • 4개의 token predictor를 갖고 있고, shared truck의 layer가 32개라면, 32-(4-1)을 하여 29개의 layer만으로 진행

3.1 Benefits scale with model size

 

  • 모델 크기가 커질수록 multi-token prediction의 이점이 더욱 커짐
  • MBPP와 HumanEval에서의 평가 결과, 더 나은 성능을 달성
  • multi-token prediction은 큰모델에서 유용한 방법으로 보임. -> 따라서 지금까지 유용하지 않은 것처럼 보일 수 있었음.

3.2 Faster inference

  • XFormers를 사용하여 heterogeneous batch sizes로 greedy self-speculative decoding 구현
    • Self-speculative decoding : 모델이 t+1 토큰을 예측하며 동시에 그 다음 t+n 토큰들도 예측을 시도
      (구체적인것은 2018논문 찾아볼것)
  • 훈련 중에 보지 않은 코드와 자연어 테스트 데이터셋에서 디코딩 속도 측정, 코드에서 3개 제안 중 평균 2.5개가 수용되어 3.0배 속도 향상, 텍스트에서 2.7배 속도 향상

3.3 Learning global patterns with multi-byte prediction

  • Next-token prediction 작업이 지역적 패턴에 집중한다는 것을 보여주기 위해 byte-level tokenization의 극단적 케이스 실험
  • 314B 바이트(약 116B 토큰에 해당)에 대해 7B 파라미터 byte-level transformer 훈련
    • 8-byte(약 3개 토큰) prediction 모델이 next-byte prediction에 비하여 높은 성능 향상 달성:
      • MBPP pass@1에서 67% 더 많은 문제 해결
      • HumanEval pass@1에서 20% 더 많은 문제 해결
    • 문장을 읽을 때, 한 글자씩 읽는 것이 아니라 여러 글자를 한번에 보고 이해하는 것과 유사한 개념. 이렇게 함으로써 모델이 더 넓은 문맥을 한번에 이해하고 처리할 수 있게 된다.

3.4 Searching for the optimal n

  • 200B 토큰의 코드로 훈련된 7B 모델에 대해 실험을 진행( n = 1, 2, 4, 6, 8)
  • 4 token으로 훈련한것이 일관적으로 높은 성능을 달성하나, APPS데이터에서는 6일경우 좋은 성능을 보임
    • 데이터 분포에  따라 최적의 N이 달라질 수 있음.

3.5 Training for multiple epochs

Multi-token training은 동일한 데이터에 대해 여러 에포크를 훈련할 때도 next-token prediction에 비해 우위를 유지. 단, 성능 향상 폭은 줄어들긴 함.

 

3.6 Finetuning multi-token predictors(Appendix F)

  • Multi-token prediction loss로 사전학습된 모델들은 파인튜닝에서도 next-token 모델들보다 더 나은 성능을 보임.
    • Section 3.3의 7B 파라미터 모델들을 CodeContests 데이터셋에서 파인튜닝하여 실험
      • 4-token prediction 모델과 next-token prediction 베이스라인을 비교, 4-token prediction 모델에서 추가 prediction heads를 제거하고 classical next-token prediction target으로 파인튜닝하는 설정도 포함
      • 결과 : 4-token prediction 모델을 파인튜닝하는 두 방식 모두 모든 pass@k 메트릭에서 next-token prediction 모델보다 더 나은 성능을 보임
        • 여기서 두 방식이란
          • n' = 4 : 4 token prediction으로 학습하고 4Token prediction로 파인튜닝
          • n' = 1 : 4 token prediction으로 사전학습된 모델에서 prediction heads를 제거하고 classical next-token prediction(즉, 1-token prediction)으로 파인튜닝
      • 4-token prediction 사전학습 하고 next-token prediction 파인튜닝을 하는 것이 가장 좋은 방법으로 보이며, 보조 태스크로 사전학습한 후 태스크별 파인튜닝을 하는 전통적인 방법들과 일치함. (자세한 내용은 Appendix F)

3.7 Multi-token prediction on natural language

  • 200B 토큰의 자연어로 4-token, 2-token, next-token prediction loss를 각각 사용하여 7B 파라미터 모델들을 훈련
  • 6가지 데이터 셋에 대해서 실험을 진행하였음 (대부분 Multiple choice 문제)
    • 2-future token prediction 모델이 next-token prediction 베이스라인과 비슷한 성능을 보임.
    • 6개 데이터 셋은 주로 multiple choice 형식의 문제나 likelihood 기반의 평가를 사용하나, 논문의 저자들은 이러한 평가 방식이 언어 모델의 생성 능력을 효과적으로 측정하기에는 적합하지 않다고 판단
      -> 추가로 요약 테스크와, 수학문제 해결 능력 평가하는 실험을 진행

  • 요약 태스크의 경우:
    • 8개의 벤치마크에서 ground-truth 요약에 대한 ROUGE 메트릭으로 평가.
    • 각 벤치마크의 훈련 데이터셋에서 3 에포크 동안 각 사전학습 모델을 파인튜닝
    • 검증 데이터셋에서 가장 높은 ROUGE-L F1 점수를 가진 체크포인트 선택
    • n = 2와 n = 4 모두 훈련 데이터셋 크기와 관계없이 next-token 베이스라인보다 ROUGE-L F1 점수가 향상됨
    • 큰 데이터셋 크기에서는 성능 격차가 줄어듦
  • 수학의 경우:
    • GSM8K 벤치마크에서 8-shot 모드로 사전학습 모델 평가
    • Chain-of-thought로 유도된 최종 답변의 정확도 측정
    • 코드 평가와 같이 답변의 다양성과 정확성을 정량화하기 위해 pass@k 메트릭 사용
    • 200B 토큰 후에는 2-token prediction 모델이 next-token 베이스라인보다 명확한 우위
    • 500B 토큰 후에는 순서가 역전됨 (next-token이 우수)
    • 4-token prediction 모델은 전반적으로 성능이 좋지 않음

4. Ablations on synthetic data

  • "multi-token prediction이 성능을 향상시킨다"는 결과가 정확히 왜, 어떻게 발생하는지 이해하기 위한 실험을 진행
  1. 작은 크기의 모델에서의 발견:
    • Induction capability(이전에 나온 패턴을 기억하고 활용하는 능력)가 크게 향상됨
    • 일반적인 next-token 방식으로는 이런 능력이 잘 생기지 않았지만, multi-token 방식에서는 이 능력이 잘 형성됨
  2.  수학 문제 해결에서의 발견:
    • Multi-token prediction을 사용하면 수학 문제 해결 능력이 향상됨
    • 이 향상 정도는 모델의 크기를 3배 늘렸을 때보다도 더 컸음
    • 즉, 단순히 모델을 크게 만드는 것보다 multi-token 방식을 사용하는 것이 더 효과적이었다는 의미

4.1 Induction capability

  • Induction : 간단한 추론 패턴을 의미함.
    • 예시) 문장에 "AB"가 포함되어 있고 나중에 "A"가 언급되면, induction은 다음에 "B"가 올 것이라고 예측하는 것
  • 실험 :
    • 모델 : 1M ~ 1B까지 임베딩 모델이 없는 작은 모델들
    • 데이터 : child 이야기 데이터셋
    • 테스트 : 원본 테스트셋에서 100개 선택해서 이름 수정
      • 이름 수정 : 원래 이름 "John" → 무작위 생성된 "Token1 Token2"
    • 실험 방법 : 이름이 최소 한 번 언급된 후 각 이름의 두 번째 토큰을 예측하는 것은 순수한 induction 태스크로 볼 수 있음
      • 이름이 최소 한 번 이상 문장에서 등장한 후에만 측정, 모델이 이전에 본 패턴을 기억하고 재현할 수 있는지를 테스트
    • 실험에서는 최대 90 에포크까지 훈련하고 테스트 메트릭에 대해 early stopping을 수행
  • 결과 : 
    • 작은 모델(30M 파라미터 이하): multi-token prediction이 매우 효과적, 큰 모델(100M 파라미터 이상): 이점이 감소
    • Multi-token prediction은 모델이 단어들의 위치 간의 관계를 학습하게 촉진하는 것으로 보임.
      • induction heads와 다른 in-context learning 메커니즘의 형성에 도움이 됨
      • 허나, induction capability가 일단 형성되면, 이러한 학습된 특징들은 induction을 현재 토큰에서 해결할 수 있게됨
        -> 이 시점부터는 multi-token prediction이 실제로 이 제한된 벤치마크에서 성능을 저하시킬 수 있음
      • 허나, 이는 3.1의 결과와는 상반됨
        -> 3.1의 결과에서는multi-token prediciton이 더 기여할 수 있는, 좀더 고차원적인 in-context reasoning이 있다고 추측한다.

4.2 Algorithmic reasoning

  • induction보다 더 복잡한 형태의 In context reasoning을 측정하기 위해 algorithmic reasoning task 수행.
  • 결론 : 

 

  • 다중 토큰 예측으로 훈련하면 작업의 난이도에 관계없이 정확도가 증가
  • 도메인 외 일반화 성능을 상당히 개선시키지만, 여전히 성능 낮음

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 허나, 모델 크기를 세배로 늘리는 것보다, 다중 토큰을 사용하는 것이 효과가 더 좋음.

 

 

 

 

 

 

 

 

 

 

 

 

 

 

5. Why does it work? Some speculation

직관적으로 설명하면, 훈련-평가 불일치(mismatch)를 완화함으로서 성능 향상을 일으킴. 불일치의 원인으로는 훈련에서는 Teacher forcing을 진행하는데반에 추론에서는 autoregressive한 생성을 진행하기 때문이다. 

 

5.1 Lookahead reinforces choice points

텍스트 생성에서 모든 토큰들이 중요한 것은 아님. 일부 토큰은 약간 바뀌어도 문장의 내용에 큰 영향을 끼치지 않는다. 반면, 선택 지점(choice points)으로 불리는 특정 토큰은 문맥, 흐름에 중요한 영향을 끼침. 

multi token prediction 훈련은 각 토큰이 후속 토큰과 얼마나 연관이 있는지에 따라 중요한지 판단하여, 중요한 토큰에 더 많은 학습 비중(가중치)를 둔다.

 

예를 들어, 특정 토큰이 예측하기 어려운 부분이라고 하자. 특정 토큰의 순서가 예측하기 어렵기에, 당연하게도 이후에 결과들 또한 예측이 불가하다. 허나 다중 토큰 예측 손실은 이러한 분기점에 더 많은 학습 비중을 두어 모델이 해당 지점을 잘 처리하도록 학습되었다.

 

쉽게 말해, 다중 토큰 예측은 모델이 이러한 중요한 선택 지점을 더 잘 다룰 수 있도록 학습을 강화한다.

 

 

 

 

 

 

 

 

5.2. Information-theoretic argument

정보이론적으로 설명하였는데 생략함.

 

6. Related Work

Language modeling losses

Multi-token prediction in language modelling

 

 

Self-speculative decoding

Multi-target prediction