DeepLearing/NLP(FineTuning)

[논문리뷰] Scaling Laws for Precision

notdecidedyet 2024. 12. 5. 22:16

0 Abstract

  1. 연구 배경 및 문제점
    1. Low precision training과 inference가 language model의 품질과 비용에 영향을 미침
    2. 하지만 현재 scaling law들은 이를 고려하지 않음
  2. 핵심 발견
    • 낮은 precision으로 training하면 모델의 effective parameter count가 감소함
    • Post-training quantization으로 인한 성능 저하는 학습 데이터가 많아질수록 증가함
      • 이는 오히려 추가적인 pretraining 데이터가 해로울 수 있다는 것을 의미함
    • 더 큰 모델을 더 낮은 precision으로 training하는 것이 compute optimal일 수 있음
  3. 제안
    • Training과 inference를 위한 "precision-aware" scaling law 개발
    • Post와 pretraining quantization에 대한 통합된 단일 functional form 도출

 

1 Introduction

  • 최근 딥러닝
    • Deep learning에서 scale이 발전의 핵심 동력으로 부상
      Scale (Model, Dataset, Computational Scale)
    • 기존 연구는 모델/데이터셋 크기의 trade-off에 초점
    • Precision이 비용과 성능에 영향을 미치는 중요한 제3의 요소임에도 불구하고 충분히 연구되지 않음
  • 최신 Deep Learning의 Precision 동향
    • Llama-3 시리즈: BF16 사용해서 학습
    • FP8, FP4, Binary/ternary scale training 시도
  • Questions : 
    • What are the tradeoffs between precision, parameters and data?
    • How do they compare for pretraining and inference?
  • Precision과 Scaling을 같이 연구하는 것은 까다로움.
    • Scaling은 일반적이고 추상적인 수학적 공식(functional form) 추구하며 구체적인 구현 방법에 대해서는 논의하지 않음. 예를들어 "모델 크기가 2배가 되면 성능은 1.5배 향상된다" 같은 일반적인 법칙 도출
    • 반면, Precision은 매우 구체적인 세부사항에 집중한다. 예를들어, 어떤 레이어를 몇 비트로 줄일 것인가?, weight를 어떤 방식으로 반올림할 것인가?, activation function은 어떻게 quantize할 것인가?
    • 이 연구에서는, Quantization의 복잡한 기술적 세부사항들을 "loss scaling"이라는 더 단순한 개념으로 추상화
  • 주요 연구 결과
    • Post-train quantization 관련
      • 특정 시점 이후, training 데이터셋을 바탕으로 훈련하는 것은 해로울 수 있음
      • 파라미터와 데이터에 따른 성능 저하 패턴 발견
    • Quantized Training 관련( 아래 내용들은 추후 서술)
      • quantization-aware-training(weights only)과 low-precision training(weights, activations, attention 모두 quantized) 에 대해서 실험 
      • Compute-optimal pretraining precision은 일반적으로 compute budget과 독립적
      • 모델 크기 제한 시 compute-optimal precision이 compute에 따라 점진적 증가
        • 더 많은 컴퓨팅 자원을 사용할수록, 최적의 precision도 점차 높아져야 함
  • 연구 규모 및 방법
    • 465개 language model 실험
    • 3-16 bit precision 범위
    • Multiple precision으로 post-train quantize 테스트
    • $N$개의 파라미터를 가진 language model에 대해, $D$개의 토큰으로 training precision $P_{train}$으로 학습하고, post-train weight precision $P_{post}$로 quantize한 경우, 다음과 같은 형태의 unified scaling law를 찾음.
      (Loss를 대충 계산해볼 수 있음.

  • 수식
    • $A$, $B$, $E$, $\alpha$, $\beta$는 양의 fitted constant
    • $\delta_{PTQ}$ : inference 전 post-training quantization으로 인한 loss 저하를 의미
  • post-train quantization에 대한 결과는 더 많은 pretraining FLOPs가 항상 inference-time에 더 나은 모델로 이어지지는 않는다는 것을 보여주며,
    -> 단순히 컴퓨팅 리소스를 늘리는 것이 항상 최선의 전략이 아님. Precision, 데이터 크기, 모델 크기 간의 적절한 균형이 필요
  • post-train quantization시 low-precision pretraining에 대한 결과는 16-bit로 모델을 training하는 표준 관행과, 극도로 낮은(4-bit 이하) pretraining precision으로 경우 모두 최적이 아닐 수 있음을 시사함.

2. Background, Related Work, and Setup

  • Notation
    • $D$는 토큰 단위의 데이터셋 크기를 나타냄
    • $N$은 파라미터 단위의 모델 크기를 나타냄
    • $P_w$, $P_a$, $P_{kv}$는 각각 training 중의 weight, activation, key-value cache("attention")의 integer-type bit precision을 나타냄
    • $P_{post}$는 모델 inference를 위해 post-train quantize하는 precision
    • $P$ 또는 $P_{train}$이 모델의 특정 부분을 지칭하지 않고 사용될 때는, 모든 부분이 동일한 precision으로 묶여있음을 의미
    • $\delta_{PTQ}(N, D, P_{train}, P_{post})$ : post-train quantization으로 인한 inference-time loss. 
      pretrain끝난 시점의 모델과, post-training quantization을 한 것과의 loss 변화

2.1 Quantization Fundamentals: How, What, When

The Problem: Compute vs Memory-Bound Workloads.

  • 병목(Workloads) : 
    • Compute : 행렬 곱셈 연산 자체의 속도 제한
    • Memory : GPU 내부에서 데이터를 이동시키는 속도 제한
    • 병목현상은 workload의 종류에 따라 다르게 나타남.
      • Pretraining: 대규모 행렬 곱셈이 많아서 compute가 병목
      • Small-batch inference: 모델 가중치를 메모리에서 불러오는 것이 병목
      • 긴 시퀀스 처리: KV cache 관련 메모리 작업이 병목
    • 이런 다양한 병목현상이 존재하기 때문에, 연구에서는 모델의 세 가지 주요 구성 요소(weights, activations, KV cache)의 precision을 각각 따로, 그리고 함께 연구할 필요성이 있다는 것을 설명
  • Quantization: How.
    • 연산의 quantization은 일반적으로 forward/backward pass에서 어떤 계산에 관련된 행렬의 값을 반올림하는 것을 의미하며, gradient는 high/full precision으로 accumlate된다. Quantization은 일반적으로 integer나 floating-point type으로 수행된다.
    •  
    • 더보기
      부연 설명

      Forward Pass : 
      - 연산 수행시 16bit 8bit 4bit등의 낮은 Precision 사용
      - Matrix multiplication등 연산이 더 적은 메모리와 계산으로 수행

      Backward Pass : 
      - Gradient 계산은 high/full precision으로 수행
      - Gradient accumulation도 high/full precision으로 유지 <- 학습의 안정성에 매우 중요한 역할을 함.
  • Quantization : What
    • Only weights. "Quantization-aware training" Training 
      • Weight만 Quantize
        예를들어 8bit으로 quantize
      • Matrix multiplication은 High Precision사용
        예를들어 8bit quantize weight을 -> 32bit으로 변환하고 Multiplication진행
        따라서, Quantize안한 것과 동일한 compute cost를 갖음.
        이렇게 하는 이유는 Training 중에 quantization의 영향을 미리 경험하여, 모델이 낮은 precision에 적응하도록 함
      • 더보기
        예)
        - Weight: 32-bit float
        original_weight = 3.14159265359

        - 1. Quantize step (8-bit로 quantize)
        quantized_weight = 3  # 정밀도 손실 발생

        - 2. Dequantize step (다시 32-bit float으로)
        dequantized_weight = 3.0  # 32-bit이지만 quantize 효과 반영된 값

        - 3. 실제 matrix multiplication
        output = matrix_multiply(dequantized_weight, input) # (dequantized_weight(3.0)를 사용해 32-bit precision으로 연산 수행)
    • Weights, activations, attention. "Low-precision training"
      • Weight, activation, attention을 quantize하면 모두 같은 precision에 있기 때문에 matrix multiplication을 low precision에서 수행할 수 있어 compute 에서 이득을 얻을 수 있음.
  • Quantization : When.
    • Quantization : training 중이나 후에 수행
      • inference-time memory 비용을 줄이고자 할 때, 먼저 post-train quantization을 시도한다
        • 만약 post-train quantization이 모델의 성능을 너무 많이 저하시킨다면, quantization-aware-training이 사용됨.
      • Post-train quantization은 일반적으로 model weight에만 적용

2.2 Scaling Laws and Parametric Fits

Scaling Laws.

  • 과거 연구에서, $L(N, D) = AN^{-\alpha} + BD^{-\beta} + E$ 형태의 function form을 사용하여 loss scaling을 모델링 함.
    • 여기서 $A$, $B$, $\alpha$, $\beta$, $E$는 양의 fitted constant(특정 상수 값)
    • 더 많은 compute이 사용 가능해질 때 데이터와 파라미터가 대략 동일한 비율로 scaling되어야 한다는 것을 발견
      • 예) compute 4배로 늘릴 수 있으면, 모델 크기도 2배로 늘리고, 학습 데이터도 2배로 늘리는 것이 최적임
      • Chinchilla-optimal 또는 Chinchilla라고 지칭
      • $D/N \approx 20$이 pretraining compute-optimal이라는 의미로 구어적으로 사용
  • 과거 연구들 : 
    • Noise Impact : 모델이나 데이터에 noise가 추가될 때 loss가 어떻게 변화하는지 연구, quantization을 일종의 noise로 볼 수 있고, 이로 인해 Loss가 어떻게 변할지 대략적으로 알 수 있음.
    • Total Model Bits관점 : total bits와 모델 성능 사이의 관계를 연구
    • Knowledge Capacity관점 : 모델이 저장할 수 있는 지식의 양(knowledge capacity)과 quantization의 관계를 연구, Quantization이 모델의 knowledge capacity를 어떻게 제한하는지 분석

Overtraining : 

  • inference 비용을 고려하여 더 작은 모델을 Chinchilla-optimal보다 상당히 더 오래 training하는 것을 의미함
  • 예를 들어, Llama-3-8B는 $D/N \approx 2000$까지 학습되었고 Gemma-2 시리즈는 $D/N > 1000$까지 학습. 이 논문에서 이러한 모델들을 "overtrained"라고 지칭하며, token/parameter 비율 $D/N$이 전반적으로 판단 기준이 된다.
    • Chinchilla-optimal : 모델 크기(N)와 데이터 크기(D) 사이의 이상적인 비율을 약 1:20
      • 예: Llama-3-8B는 파라미터 1개당 약 2000개의 토큰으로 학습됨 (Chinchilla-optimal의 100배)
  • 따라서, 이 연구에서도, $D/N \approx 10^3$까지의 실험을 수행, 이 논문의 아이디어인 scaling law가 찾은 예측을 $D/N \approx 10^5$까지 분석한다.

2.3 Setup

  • 데이터 : Dolma V1.7 데이터셋에서 OLMo-style모델의 suite을 학습 및 평가
  • 모델 : 표준 Transformer++ 사용 (파라미터는 Appendix A 참조)
  • $N \in [30, 60, 110, 220]$ 백만 파라미터(non-embedding 임베딩 제외)와 $D \in [1.5, 3, 6, 13, 26]$ 십억 토큰에 걸친 language model pretraining 의 조합으로 실험 진행
  • (weights, activations, attention) 에 대해서 조합으로 실험 진행(8가지 조합)

 

3. Scaling Laws for Post-Train Quantization

  • 가장 쉬운 quantization 기술 : off-the-shelf 모델을 post-train quantize
  • 이 섹션에서는 BF16으로 학습된 모델들을 살펴보고, 이 모델들에 GPTQ를 사용하여 post-train quantization을 수행
  • (Appendix F)Quantization loss 저하 $\delta_{PTQ}$를 정량화.
    • Quantization 전 모델의 loss와 Quantization 후 모델의 loss의 차이로 정량화를 계산
    • 일반적으로는 더 많은 데이터로 pre-training하면 모델 성능이 좋아질 것으로 기대하나, 이렇게 학습한 모델의 경우 quantization의 성능 하락폭이 더 컸음.

3.1 Overtrained Models Degrade more when Post-Train Quantized

  • N : 모델 크기
  • D : 데이터 크기
  • x-axis : D/N
  • 1st row : Pretrain후 Post quantization을 했을 경우의 Loss
  • 2nd row : Pretrain후 Post quantization과의 Loss차를 구한 값
  • 발견 : 
    • 데이터 크기와 성능 저하의 관계 : 
      • $\delta_{PTQ}$ (quantization으로 인한 성능 저하)는 training 데이터가 많아질수록 증가
      • 이는 모든 크기의 모델에서 일관되게 관찰됨 (2nd Row) 단 모델의 크기가 커질수록 Degradation의 기울기가 다름.
    • 모델 크기와 성능 저하의 관계 (2nd Row)
      • 같은 양의 데이터로 훈련했을 때, 더 큰 모델이 quantization으로 인한 성능 저하가 더 적음
    • Precision의 영향:
      • quantization precision을 낮출수록 (즉, 더 적은 비트를 사용할수록), 성능 저하가 지수적으로(exponentially) 증가
  • 위 발견을 바탕으로 수학적 수식을 도출함.
    • $\delta_{PTQ}(N, D, P_{post}) = C_T \left(\frac{D^{\gamma_D}}{N^{\gamma_N}}\right)e^{-P_{post}/\gamma_{post}}$
    • 여기서 $C_T$, $\gamma_D$, $\gamma_N$, $\gamma_{post}$는 양수이며 fitted constant이다.
    • 여기서, $\gamma_D$와 $\gamma_N$의 fitted 값이 비슷하여, $\frac{D^{\gamma_D}}{N^{\gamma_N}}$이것을 $\frac{D/N}$으로 근사적으로 볼 수 있음.
    • 이 부분을 해석하면, 모델이 더 많은 데이터로 학습될수록 더 많은 정보를 weight에 저장하게 되어, 각각의 Weight더 많은 중요한 정보를 담게 된다. 따라서 이 Weight값들을 quantization하면 저장된 중요 정보들이 손상될 수 있음.
    • 따라서 더 많은 데이터로 학습된 모델이 quantization에 더 취약함.

 

4. Scaling Laws for Quantized Training

  • 의도
    • Training상황에서, 세 가지 주요 구성 요소(weights, activations, KV cache)를 각각 다른 precision으로 실험
    • 각 구성 요소의 precision을 3-bit부터 12-bit까지 다양하게 테스트
    • 비교를 위해 BF16(Brain Floating Point 16-bit) 버전도 함께 학습
  • Details
    • 학습 시의 precision만 변경
    • 테스트 시의 precision은 고정

4.1 Quantization-Aware-Training: Quantizing Weights During Training has a Consistent and Predictable Effect

  • 실험 설정
    • 먼저 activation($P_a$)과 key-value cache($P_{kv}$)의 precision은 높게 고정
    • 데이터셋 크기는 13B 토큰으로 고정
    • Weight precision($P_w$)과 모델 크기($N$)를 다양하게 변화시키면서 실험
  • 발견
    • 왼쪽 그래프 ($N_{eff}/N$ vs Precision):
      • Y축은 $N_{eff}/N$ 비율을 나타내며, 이는 effective parameter count가 실제 parameter count에 비해 얼마나 되는지를 의미함
      • 파란선 : Precision이 높아짐에 따라 $N_{eff}/N$ 비율이 1에 가까워짐.
        -> 모델이 가진 실제 파라미터를 더 효과적으로 활용
    • 중앙 그래프 
      • X축 :  모델 크기(N, millions)
      • Y축 : training 중의 weight precision(bits)
      • 각 contour line은 동일한 loss 값을 가지는 지점들을 연결
      • 해석 : 모델 크기와 precision 사이의 trade-off 관계를 보여줌
    • 오른쪽 그래프 
      • Equation 3을 사용한 이론적 예측 그래프
    • Trade-off 관계 발견: 모델의 weight precision과 파라미터 수 사이에 관계가 있음
      • 예: Loss(적은 파라미터 + 높은 precision) = Loss(많은 파라미터 + 낮은 precision). 두 조합은 비슷한 성능을 달성할 수 있음.
    • Precision변화 효과(Figure3 첫번째) : 
      • 낮은 precision 구간: precision을 조금만 올려도 성능이 크게 향상됨
      • 높은 precision 구간(6-7 bits 이상): precision을 더 올려도 성능 향상이 미미함(포화 현상)
        -> 8bit이 최적인가? 4Bit quantization의 경우, Double Quant를 진행했을까?
    • effective parameter count ($N_{eff}$)라는 개념을 제안
      • $N_{eff}(N, P_w) = N(1 - e^{-P_w/\gamma_w})$
        • $N$: 실제 파라미터 수
        • $P_w$: weight precision
        • $\gamma_w$: weight의 precision 민감도를 나타내는 상수
        • $e^{-P_w/\gamma_w}$: precision에 따른 감소 factor
    • 최종 Scaling Law : Chinchilla scaling law 확장
      • $L(N, D) = A[N(1 - e^{-P_w/\gamma_w})]^{-\alpha} + BD^{-\beta} + E$            (3)
        • $N$ 대신 $N_{eff}$를 사용하여 precision의 효과를 반영

 

4.2 Low-Precision-Training: The Effects of Quantizing Weights, Activations, and Attention are Compositional and Multiplicative

  • 현대 GPU에서 Matrix Multiplication을 하기 위해서는 같은 Precision에 있어야 계산이 가능함.
    Weight, Activation, KV등이 모두 같은 Precision이여야 함.
  • $P_a$와 $P_{kv}$의 scaling 동작도 분석 시작

Precision of activations and KV cache affect loss in a similar way

  • activation precision($P_a$)과 key-value cache precision($P_{kv}$)을 각각 독립적으로 변화
  • 두 요소의 scaling 동작이 weight precision($P_w$)의 scaling 동작과 매우 유사한 패턴을 보임
  • Equation 3의 형태($N(1 - e^{-P/\gamma})$)로 표현될 수 있음.

Constants fitted marginally and jointly make similarly good predictions

  • weight, activation, attention 사이의 상호작용을 살펴봄
  • 먼저, 모델의 한 부분(weight, activation, 또는 KV cache)의 precision만 변화시키면서 관찰
  • 이후, 여러 구성요소(weight, activation, KV cache)의 precision을 동시에 low precision적용하여 관찰
  • 독립성을 확인하기 위해, marginally fitted constant를 가진 모델과 jointly fitted constant를 가진 모델의 예측력을 비교하여 테스트
  • Joint fit과 Combined Marginals가 거의 동일한 성능을 보이는 것은 구성요소들이 실제로 독립적으로 작용한다는 증거
    • Marginal Sweep : weight precision과, 추정식을 통해서 fitness가 매우 높음을 알 수 있음.
    • Joint fit : Weight, Activation, KV cache precision을 모두 동시에 변화시키면서 분석, 이것도 Fitness가 매우 높음
    • combined Marginals : 각 구성요소를 독립적으로 계산하고 결과를 곱하여 구한 값. 여전히 Fitness가 매우 높음.

  • Weight, activation, KV cache를 training 중에 quantize하는 효과는 독립적이고 곱셈적인 것으로 잘 모델링 된다

 

4.3 Implications For Pretraining

  • Precision과 Compute의 관계
    • 모델을 Precision $P$로 training할 때 ($P_w = P_a = P_{kv} = P$ 인 경우), Compute 비용은 $P$에 선형적으로 증가
    • 예: 8-bit로 training하면 16-bit에 비해 약 절반의 compute 비용
  • Chinchilla 비용 모델의 일반화
    • 기존 Chinchilla: $C = 6ND$ FLOPs (16-bit 기준)
    • 일반화된 모델: $C = \frac{6}{16}NDP$ : 16-bit를 기준으로 normalize하기 위해 16으로 나눔
    • $\min_{N,D,P} L(N, D, P) = A[N(1 - e^{-P/\gamma})^3]^{-\alpha} + BD^{-\beta} + E$
      subject to $C = \frac{6}{16}NDP$
      • $A$, $\gamma$, $\alpha$, $B$, $\beta$, $E$ : fitted constant
      • $N$: 모델 크기 (파라미터 수)
      • $D$: 데이터 크기 (토큰 수)
      • $P$: Precision (bits)
      • $(1 - e^{-P/\gamma})^3$: Precision으로 인한 effective 파라미터 감소 효과
      • $C$: 총 compute 예산
    • 해석 : 주어진 compute내에서 최적의 Loss를 구할 수 있는, (model size / dataset size / precision ) 조합을 찾는 것

 

4.3.1 If You Must Train in Low Precision, Increase Parameters Before Data

  • $P$가 고정되고 $C \propto NDP$인 제약 조건 하에서 $L(N, D)$를 최소화
  • 최적화 식은 아래와 같다.
    • $\frac{N^(P, C)}{N_{Ch}(C)} \propto [1 - e^{-P/\gamma}]^{-\frac{3\alpha}{\alpha+\beta}}P^{-\frac{\beta}{\alpha+\beta}}$ 그리고 $\frac{D^(P, C)}{D_{Ch}(C)} \propto [1 - e^{-P/\gamma}]^{\frac{3\alpha}{\alpha+\beta}}P^{\frac{\beta}{\alpha+\beta}}$
  • training의 precision이 감소함에 따라 data는 줄이고 parameters는 증가시켜야 함을 시사
  • 매우 낮은 precision에서는 effective parameter count가 0에 가까워지므로, data가 effective parameters를 지나치게 많이 초과하기 때문에 parameter count를 증가시키는 것이 compute-optimal임

4.3.2 Compute-Optimal Pretraining Precision is in General Independent of Compute

  • $C \propto NDP$인 제약 조건 하에서 $L(N, D, P)$를 함께 최소화
  • 최적화 문제를 수학적으로 풀면, 최적의 precision($P^*$)이 compute budget($C$)과 무관하다는 것을 발견
    • 실제로 실험해보니 이 최적의 precision은 약 7 bits
  • 의미
    • 모델을 학습할 때, compute budget이 얼마든 간에 항상 약 7 bits precision을 사용하는 것이 최적
    • 더 많은 계산 자원이 생기면, precision을 올리는 대신 모델 크기($N$)나 데이터 크기($D$)를 늘리는 것이 더 효율적
      • 현재 산업 표준인 16-bit(BF16) training이 최적이 아닐 수 있다는 것을 시사
      • 4-bit 이하의 매우 낮은 precision으로 가는 것도 좋지 않을 수 있다는 것을 보임
    • 4비트로 학습시키면 같은 성능을 내기 위해 모델 크기를 4배 이상 키워야 한다. 결국 리소스 측면에서 비효율적이라는 것을 의미

  • 모델 크기: 220M부터 1.6B 파라미터까지의 다양한 크기의 모델 테스트
  • Precision 범위: FP4(4-bit floating point)부터 FP32(32-bit floating point)까지 테스트
  • 4-bit precision까지는 integer type으로 예측한 결과가 실제 floating-point 실험과 잘 맞음
  • 4-bit 이하에서는 차이가 발생하기 시작했는데, 이는 두 타입의 근본적인 차이 때문:
    • Integer type: 모든 비트가 동일한 가치를 가진다고 가정
    • Floating-point type: 비트가 두 부분(지수부/exponent와 가수부/mantissa)으로 나뉘며, 각각이 다른 방식으로 정확도에 영향을 줌
  • 연구팀이 예측한 최적의 Precision은 7bit이였는데, 실험에서도 비슷하게 나왔음.

4.3.3 But Compute-Optimal Pretraining Precision Can Increase in Compute if Model Size N is Constrained

  • $N$(모델 크기)이 고정되어 있는 상황
  • 기존에는 모든 크기의 모델을 같은 precision으로 학습시키는 것이 일반적
    • 하지만 연구 결과, 모델 크기에 따라 다른 precision을 사용하는 것이 더 optimal할 수 있음
    • 구체적으로, compute-optimal precision은 $\log C$에 비례하여 증가
  • 고정된 모델 크기 상황
    • 모델 크기($N$)가 미리 정해져 있을 때 : 최적의 precision은 compute 예산에 비례하여 로그 스케일로 증가
  • Overtraining 상황에서
    • 많은 데이터로 "과학습"(overtrained)되는 경우, 오히려 더 높은 precision으로 학습하는 것이 compute 효율적
  • 결론 : 
    • 모델 설계 시 무조건 낮은 precision(예: 4-bit)이나 높은 precision(16-bit)을 사용하는 것이 아니라, 상황에 따라 적절한 precision을 선택
    • 특히 대규모 데이터로 학습하는 경우, 흔히 생각하는 것과 달리 더 높은 precision을 고려할 것

 

5. A Unified Scaling Law for Precision

  • training/post-training을 모두 예측하는 unified functional form으로 결합
  • 함수 $\delta_{PTQ}$를
    • $\delta_{PTQ}(N, D, P_{post})$에서 $\delta_{PTQ}(N, D, P_{train}, P_{post})$로 다룸
  • 뭐 어쨌든 결론적으로, 낮은 precision에서 학습한 모델일수록 post-train quantization에서 성능 저하가 낮다.
    • Robustification
      • 모델을 낮은 precision으로 학습시키면, 모델이 처음부터 quantization noise에 적응하면서 학습됨
      • 험한 환경에서 단련된 것처럼 작용하여, 나중에 post-train quantization을 할 때 더 잘 견딜 수 있게 됨
        • post-train quantization : pretrain model을 추론시에 quantization을 하는 것을 의미(추가적인 학습은 없음)
    • Overtraining
      • 낮은 precision으로 학습하면 모델의 effective parameter count가 감소함
      • -> effective parameter count가 낮을수록 post-train quantization으로 인한 성능 저하가 더 커짐

Two  Competing effects at play during post-train quantization

  • $P_w$, $P_a$, $P_{kv}$ 중 어느 것이든 low precision으로 training하면 모델이 "quantization noise"에 강건한 weight를 학습하도록 강제하므로 PTQ 하에서 덜 저하됨.
  • 그러나, low precision으로 학습된 모델의 감소된 $N \mapsto N_{eff}$ effective parameter 수는 Section 3에서 발견한 것처럼 low precision으로 학습된 모델이 더 많이 저하될 것임이 생각됨
  • 정리하자면, low precision으로 학습하면 quantization 후 (긍정적) 성능 하락 적을 것으로 생각되나, (부정적) effective paramter 수늘 줄어들어 모델 성능이 줄어들 것임. 즉 두개의 효과가 서로 배반 적임.
  • 허나, low precision으로 학습하면 quantization 후 (긍정적) 성능 하락 적을 것의 영향도가 더 큰것으로 확인됨.

Modifying $\delta_{PTQ}$ to account for training precision

  • Training precision은 항상 inference precision보다 커야 함
    • 두 precision이 같을 경우 성능 저하는 0이라고 정의

  • $(N, D)$를 고정했을 때, training precision($P_w$)과 inference precision($P_{post}$) 사이에 어떤 차이라도 있으면 성능 저하가 급격히 발생
    -> 이 저하는 지수적으로 증가하는 패턴을 보임 ( Figure 7의 중간 부분 )
  • 이 발견을 바탕으로 아래와 같이 formula를 변경시켰음

 

An interpretable, unified functional form

 

6. Conclusion and Limiations