DeepLearing/NLP(FineTuning)
[논문리뷰] Scaling Laws for Precision
notdecidedyet
2024. 12. 5. 22:16
0 Abstract
- 연구 배경 및 문제점
- Low precision training과 inference가 language model의 품질과 비용에 영향을 미침
- 하지만 현재 scaling law들은 이를 고려하지 않음
- 핵심 발견
- 낮은 precision으로 training하면 모델의 effective parameter count가 감소함
- Post-training quantization으로 인한 성능 저하는 학습 데이터가 많아질수록 증가함
- 이는 오히려 추가적인 pretraining 데이터가 해로울 수 있다는 것을 의미함
- 더 큰 모델을 더 낮은 precision으로 training하는 것이 compute optimal일 수 있음
- 제안
- Training과 inference를 위한 "precision-aware" scaling law 개발
- Post와 pretraining quantization에 대한 통합된 단일 functional form 도출
1 Introduction
- 최근 딥러닝
- Deep learning에서 scale이 발전의 핵심 동력으로 부상
Scale (Model, Dataset, Computational Scale) - 기존 연구는 모델/데이터셋 크기의 trade-off에 초점
- Precision이 비용과 성능에 영향을 미치는 중요한 제3의 요소임에도 불구하고 충분히 연구되지 않음
- Deep learning에서 scale이 발전의 핵심 동력으로 부상
- 최신 Deep Learning의 Precision 동향
- Llama-3 시리즈: BF16 사용해서 학습
- FP8, FP4, Binary/ternary scale training 시도
- Questions :
- What are the tradeoffs between precision, parameters and data?
- How do they compare for pretraining and inference?
- Precision과 Scaling을 같이 연구하는 것은 까다로움.
- Scaling은 일반적이고 추상적인 수학적 공식(functional form) 추구하며 구체적인 구현 방법에 대해서는 논의하지 않음. 예를들어 "모델 크기가 2배가 되면 성능은 1.5배 향상된다" 같은 일반적인 법칙 도출
- 반면, Precision은 매우 구체적인 세부사항에 집중한다. 예를들어, 어떤 레이어를 몇 비트로 줄일 것인가?, weight를 어떤 방식으로 반올림할 것인가?, activation function은 어떻게 quantize할 것인가?
- 이 연구에서는, Quantization의 복잡한 기술적 세부사항들을 "loss scaling"이라는 더 단순한 개념으로 추상화
- 주요 연구 결과
- Post-train quantization 관련
- 특정 시점 이후, training 데이터셋을 바탕으로 훈련하는 것은 해로울 수 있음
- 파라미터와 데이터에 따른 성능 저하 패턴 발견
- Quantized Training 관련( 아래 내용들은 추후 서술)
- quantization-aware-training(weights only)과 low-precision training(weights, activations, attention 모두 quantized) 에 대해서 실험
- Compute-optimal pretraining precision은 일반적으로 compute budget과 독립적
- 모델 크기 제한 시 compute-optimal precision이 compute에 따라 점진적 증가
- 더 많은 컴퓨팅 자원을 사용할수록, 최적의 precision도 점차 높아져야 함
- Post-train quantization 관련
- 연구 규모 및 방법
- 465개 language model 실험
- 3-16 bit precision 범위
- Multiple precision으로 post-train quantize 테스트
- $N$개의 파라미터를 가진 language model에 대해, $D$개의 토큰으로 training precision $P_{train}$으로 학습하고, post-train weight precision $P_{post}$로 quantize한 경우, 다음과 같은 형태의 unified scaling law를 찾음.
(Loss를 대충 계산해볼 수 있음.
- 수식
- $A$, $B$, $E$, $\alpha$, $\beta$는 양의 fitted constant
- $\delta_{PTQ}$ : inference 전 post-training quantization으로 인한 loss 저하를 의미
- post-train quantization에 대한 결과는 더 많은 pretraining FLOPs가 항상 inference-time에 더 나은 모델로 이어지지는 않는다는 것을 보여주며,
-> 단순히 컴퓨팅 리소스를 늘리는 것이 항상 최선의 전략이 아님. Precision, 데이터 크기, 모델 크기 간의 적절한 균형이 필요 - post-train quantization시 low-precision pretraining에 대한 결과는 16-bit로 모델을 training하는 표준 관행과, 극도로 낮은(4-bit 이하) pretraining precision으로 경우 모두 최적이 아닐 수 있음을 시사함.
2. Background, Related Work, and Setup
- Notation
- $D$는 토큰 단위의 데이터셋 크기를 나타냄
- $N$은 파라미터 단위의 모델 크기를 나타냄
- $P_w$, $P_a$, $P_{kv}$는 각각 training 중의 weight, activation, key-value cache("attention")의 integer-type bit precision을 나타냄
- $P_{post}$는 모델 inference를 위해 post-train quantize하는 precision
- $P$ 또는 $P_{train}$이 모델의 특정 부분을 지칭하지 않고 사용될 때는, 모든 부분이 동일한 precision으로 묶여있음을 의미
- $\delta_{PTQ}(N, D, P_{train}, P_{post})$ : post-train quantization으로 인한 inference-time loss.
pretrain끝난 시점의 모델과, post-training quantization을 한 것과의 loss 변화
2.1 Quantization Fundamentals: How, What, When
The Problem: Compute vs Memory-Bound Workloads.
- 병목(Workloads) :
- Compute : 행렬 곱셈 연산 자체의 속도 제한
- Memory : GPU 내부에서 데이터를 이동시키는 속도 제한
- 병목현상은 workload의 종류에 따라 다르게 나타남.
- Pretraining: 대규모 행렬 곱셈이 많아서 compute가 병목
- Small-batch inference: 모델 가중치를 메모리에서 불러오는 것이 병목
- 긴 시퀀스 처리: KV cache 관련 메모리 작업이 병목
- 이런 다양한 병목현상이 존재하기 때문에, 연구에서는 모델의 세 가지 주요 구성 요소(weights, activations, KV cache)의 precision을 각각 따로, 그리고 함께 연구할 필요성이 있다는 것을 설명
- Quantization: How.
- 연산의 quantization은 일반적으로 forward/backward pass에서 어떤 계산에 관련된 행렬의 값을 반올림하는 것을 의미하며, gradient는 high/full precision으로 accumlate된다. Quantization은 일반적으로 integer나 floating-point type으로 수행된다.
- 더보기
부연 설명
Forward Pass :
- 연산 수행시 16bit 8bit 4bit등의 낮은 Precision 사용
- Matrix multiplication등 연산이 더 적은 메모리와 계산으로 수행
Backward Pass :
- Gradient 계산은 high/full precision으로 수행
- Gradient accumulation도 high/full precision으로 유지 <- 학습의 안정성에 매우 중요한 역할을 함.
- Quantization : What
- Only weights. "Quantization-aware training" Training
- Weight만 Quantize
예를들어 8bit으로 quantize - Matrix multiplication은 High Precision사용
예를들어 8bit quantize weight을 -> 32bit으로 변환하고 Multiplication진행
따라서, Quantize안한 것과 동일한 compute cost를 갖음.
이렇게 하는 이유는 Training 중에 quantization의 영향을 미리 경험하여, 모델이 낮은 precision에 적응하도록 함 -
더보기예)
- Weight: 32-bit float
original_weight = 3.14159265359
- 1. Quantize step (8-bit로 quantize)
quantized_weight = 3 # 정밀도 손실 발생
- 2. Dequantize step (다시 32-bit float으로)
dequantized_weight = 3.0 # 32-bit이지만 quantize 효과 반영된 값
- 3. 실제 matrix multiplication
output = matrix_multiply(dequantized_weight, input) # (dequantized_weight(3.0)를 사용해 32-bit precision으로 연산 수행)
- Weight만 Quantize
- Weights, activations, attention. "Low-precision training"
- Weight, activation, attention을 quantize하면 모두 같은 precision에 있기 때문에 matrix multiplication을 low precision에서 수행할 수 있어 compute 에서 이득을 얻을 수 있음.
- Only weights. "Quantization-aware training" Training
- Quantization : When.
- Quantization : training 중이나 후에 수행
- inference-time memory 비용을 줄이고자 할 때, 먼저 post-train quantization을 시도한다
- 만약 post-train quantization이 모델의 성능을 너무 많이 저하시킨다면, quantization-aware-training이 사용됨.
- Post-train quantization은 일반적으로 model weight에만 적용
- inference-time memory 비용을 줄이고자 할 때, 먼저 post-train quantization을 시도한다
- Quantization : training 중이나 후에 수행
2.2 Scaling Laws and Parametric Fits
Scaling Laws.
- 과거 연구에서, $L(N, D) = AN^{-\alpha} + BD^{-\beta} + E$ 형태의 function form을 사용하여 loss scaling을 모델링 함.
- 여기서 $A$, $B$, $\alpha$, $\beta$, $E$는 양의 fitted constant(특정 상수 값)
- 더 많은 compute이 사용 가능해질 때 데이터와 파라미터가 대략 동일한 비율로 scaling되어야 한다는 것을 발견
- 예) compute 4배로 늘릴 수 있으면, 모델 크기도 2배로 늘리고, 학습 데이터도 2배로 늘리는 것이 최적임
- Chinchilla-optimal 또는 Chinchilla라고 지칭
- $D/N \approx 20$이 pretraining compute-optimal이라는 의미로 구어적으로 사용
- 과거 연구들 :
- Noise Impact : 모델이나 데이터에 noise가 추가될 때 loss가 어떻게 변화하는지 연구, quantization을 일종의 noise로 볼 수 있고, 이로 인해 Loss가 어떻게 변할지 대략적으로 알 수 있음.
- Total Model Bits관점 : total bits와 모델 성능 사이의 관계를 연구
- Knowledge Capacity관점 : 모델이 저장할 수 있는 지식의 양(knowledge capacity)과 quantization의 관계를 연구, Quantization이 모델의 knowledge capacity를 어떻게 제한하는지 분석
Overtraining :
- inference 비용을 고려하여 더 작은 모델을 Chinchilla-optimal보다 상당히 더 오래 training하는 것을 의미함
- 예를 들어, Llama-3-8B는 $D/N \approx 2000$까지 학습되었고 Gemma-2 시리즈는 $D/N > 1000$까지 학습. 이 논문에서 이러한 모델들을 "overtrained"라고 지칭하며, token/parameter 비율 $D/N$이 전반적으로 판단 기준이 된다.
- Chinchilla-optimal : 모델 크기(N)와 데이터 크기(D) 사이의 이상적인 비율을 약 1:20
- 예: Llama-3-8B는 파라미터 1개당 약 2000개의 토큰으로 학습됨 (Chinchilla-optimal의 100배)
- Chinchilla-optimal : 모델 크기(N)와 데이터 크기(D) 사이의 이상적인 비율을 약 1:20
- 따라서, 이 연구에서도, $D/N \approx 10^3$까지의 실험을 수행, 이 논문의 아이디어인 scaling law가 찾은 예측을 $D/N \approx 10^5$까지 분석한다.
2.3 Setup
- 데이터 : Dolma V1.7 데이터셋에서 OLMo-style모델의 suite을 학습 및 평가
- 모델 : 표준 Transformer++ 사용 (파라미터는 Appendix A 참조)
- $N \in [30, 60, 110, 220]$ 백만 파라미터(non-embedding 임베딩 제외)와 $D \in [1.5, 3, 6, 13, 26]$ 십억 토큰에 걸친 language model pretraining 의 조합으로 실험 진행
- (weights, activations, attention) 에 대해서 조합으로 실험 진행(8가지 조합)
3. Scaling Laws for Post-Train Quantization
- 가장 쉬운 quantization 기술 : off-the-shelf 모델을 post-train quantize
- 이 섹션에서는 BF16으로 학습된 모델들을 살펴보고, 이 모델들에 GPTQ를 사용하여 post-train quantization을 수행
- (Appendix F)Quantization loss 저하 $\delta_{PTQ}$를 정량화.
- Quantization 전 모델의 loss와 Quantization 후 모델의 loss의 차이로 정량화를 계산
- 일반적으로는 더 많은 데이터로 pre-training하면 모델 성능이 좋아질 것으로 기대하나, 이렇게 학습한 모델의 경우 quantization의 성능 하락폭이 더 컸음.
3.1 Overtrained Models Degrade more when Post-Train Quantized
- N : 모델 크기
- D : 데이터 크기
- x-axis : D/N
- 1st row : Pretrain후 Post quantization을 했을 경우의 Loss
- 2nd row : Pretrain후 Post quantization과의 Loss차를 구한 값
- 발견 :
- 데이터 크기와 성능 저하의 관계 :
- $\delta_{PTQ}$ (quantization으로 인한 성능 저하)는 training 데이터가 많아질수록 증가
- 이는 모든 크기의 모델에서 일관되게 관찰됨 (2nd Row) 단 모델의 크기가 커질수록 Degradation의 기울기가 다름.
- 모델 크기와 성능 저하의 관계 (2nd Row)
- 같은 양의 데이터로 훈련했을 때, 더 큰 모델이 quantization으로 인한 성능 저하가 더 적음
- Precision의 영향:
- quantization precision을 낮출수록 (즉, 더 적은 비트를 사용할수록), 성능 저하가 지수적으로(exponentially) 증가
- 데이터 크기와 성능 저하의 관계 :
- 위 발견을 바탕으로 수학적 수식을 도출함.
- $\delta_{PTQ}(N, D, P_{post}) = C_T \left(\frac{D^{\gamma_D}}{N^{\gamma_N}}\right)e^{-P_{post}/\gamma_{post}}$
- 여기서 $C_T$, $\gamma_D$, $\gamma_N$, $\gamma_{post}$는 양수이며 fitted constant이다.
- 여기서, $\gamma_D$와 $\gamma_N$의 fitted 값이 비슷하여, $\frac{D^{\gamma_D}}{N^{\gamma_N}}$이것을 $\frac{D/N}$으로 근사적으로 볼 수 있음.
- 이 부분을 해석하면, 모델이 더 많은 데이터로 학습될수록 더 많은 정보를 weight에 저장하게 되어, 각각의 Weight더 많은 중요한 정보를 담게 된다. 따라서 이 Weight값들을 quantization하면 저장된 중요 정보들이 손상될 수 있음.
- 따라서 더 많은 데이터로 학습된 모델이 quantization에 더 취약함.
4. Scaling Laws for Quantized Training
- 의도
- Training상황에서, 세 가지 주요 구성 요소(weights, activations, KV cache)를 각각 다른 precision으로 실험
- 각 구성 요소의 precision을 3-bit부터 12-bit까지 다양하게 테스트
- 비교를 위해 BF16(Brain Floating Point 16-bit) 버전도 함께 학습
- Details
- 학습 시의 precision만 변경
- 테스트 시의 precision은 고정
4.1 Quantization-Aware-Training: Quantizing Weights During Training has a Consistent and Predictable Effect
- 실험 설정
- 먼저 activation($P_a$)과 key-value cache($P_{kv}$)의 precision은 높게 고정
- 데이터셋 크기는 13B 토큰으로 고정
- Weight precision($P_w$)과 모델 크기($N$)를 다양하게 변화시키면서 실험
- 발견
- 왼쪽 그래프 ($N_{eff}/N$ vs Precision):
- Y축은 $N_{eff}/N$ 비율을 나타내며, 이는 effective parameter count가 실제 parameter count에 비해 얼마나 되는지를 의미함
- 파란선 : Precision이 높아짐에 따라 $N_{eff}/N$ 비율이 1에 가까워짐.
-> 모델이 가진 실제 파라미터를 더 효과적으로 활용
- 중앙 그래프
- X축 : 모델 크기(N, millions)
- Y축 : training 중의 weight precision(bits)
- 각 contour line은 동일한 loss 값을 가지는 지점들을 연결
- 해석 : 모델 크기와 precision 사이의 trade-off 관계를 보여줌
- 오른쪽 그래프
- Equation 3을 사용한 이론적 예측 그래프
- Trade-off 관계 발견: 모델의 weight precision과 파라미터 수 사이에 관계가 있음
- 예: Loss(적은 파라미터 + 높은 precision) = Loss(많은 파라미터 + 낮은 precision). 두 조합은 비슷한 성능을 달성할 수 있음.
- Precision변화 효과(Figure3 첫번째) :
- 낮은 precision 구간: precision을 조금만 올려도 성능이 크게 향상됨
- 높은 precision 구간(6-7 bits 이상): precision을 더 올려도 성능 향상이 미미함(포화 현상)
-> 8bit이 최적인가? 4Bit quantization의 경우, Double Quant를 진행했을까?
- effective parameter count ($N_{eff}$)라는 개념을 제안
- $N_{eff}(N, P_w) = N(1 - e^{-P_w/\gamma_w})$
- $N$: 실제 파라미터 수
- $P_w$: weight precision
- $\gamma_w$: weight의 precision 민감도를 나타내는 상수
- $e^{-P_w/\gamma_w}$: precision에 따른 감소 factor
- $N_{eff}(N, P_w) = N(1 - e^{-P_w/\gamma_w})$
- 최종 Scaling Law : Chinchilla scaling law 확장
- $L(N, D) = A[N(1 - e^{-P_w/\gamma_w})]^{-\alpha} + BD^{-\beta} + E$ (3)
- $N$ 대신 $N_{eff}$를 사용하여 precision의 효과를 반영
- $L(N, D) = A[N(1 - e^{-P_w/\gamma_w})]^{-\alpha} + BD^{-\beta} + E$ (3)
- 왼쪽 그래프 ($N_{eff}/N$ vs Precision):
4.2 Low-Precision-Training: The Effects of Quantizing Weights, Activations, and Attention are Compositional and Multiplicative
- 현대 GPU에서 Matrix Multiplication을 하기 위해서는 같은 Precision에 있어야 계산이 가능함.
Weight, Activation, KV등이 모두 같은 Precision이여야 함. - $P_a$와 $P_{kv}$의 scaling 동작도 분석 시작
Precision of activations and KV cache affect loss in a similar way
- activation precision($P_a$)과 key-value cache precision($P_{kv}$)을 각각 독립적으로 변화
- 두 요소의 scaling 동작이 weight precision($P_w$)의 scaling 동작과 매우 유사한 패턴을 보임
- Equation 3의 형태($N(1 - e^{-P/\gamma})$)로 표현될 수 있음.
Constants fitted marginally and jointly make similarly good predictions
- weight, activation, attention 사이의 상호작용을 살펴봄
- 먼저, 모델의 한 부분(weight, activation, 또는 KV cache)의 precision만 변화시키면서 관찰
- 이후, 여러 구성요소(weight, activation, KV cache)의 precision을 동시에 low precision적용하여 관찰
- 독립성을 확인하기 위해, marginally fitted constant를 가진 모델과 jointly fitted constant를 가진 모델의 예측력을 비교하여 테스트
- Joint fit과 Combined Marginals가 거의 동일한 성능을 보이는 것은 구성요소들이 실제로 독립적으로 작용한다는 증거
- Marginal Sweep : weight precision과, 추정식을 통해서 fitness가 매우 높음을 알 수 있음.
- Joint fit : Weight, Activation, KV cache precision을 모두 동시에 변화시키면서 분석, 이것도 Fitness가 매우 높음
- combined Marginals : 각 구성요소를 독립적으로 계산하고 결과를 곱하여 구한 값. 여전히 Fitness가 매우 높음.
- Weight, activation, KV cache를 training 중에 quantize하는 효과는 독립적이고 곱셈적인 것으로 잘 모델링 된다
4.3 Implications For Pretraining
- Precision과 Compute의 관계
- 모델을 Precision $P$로 training할 때 ($P_w = P_a = P_{kv} = P$ 인 경우), Compute 비용은 $P$에 선형적으로 증가
- 예: 8-bit로 training하면 16-bit에 비해 약 절반의 compute 비용
- Chinchilla 비용 모델의 일반화
- 기존 Chinchilla: $C = 6ND$ FLOPs (16-bit 기준)
- 일반화된 모델: $C = \frac{6}{16}NDP$ : 16-bit를 기준으로 normalize하기 위해 16으로 나눔
- $\min_{N,D,P} L(N, D, P) = A[N(1 - e^{-P/\gamma})^3]^{-\alpha} + BD^{-\beta} + E$
subject to $C = \frac{6}{16}NDP$- $A$, $\gamma$, $\alpha$, $B$, $\beta$, $E$ : fitted constant
- $N$: 모델 크기 (파라미터 수)
- $D$: 데이터 크기 (토큰 수)
- $P$: Precision (bits)
- $(1 - e^{-P/\gamma})^3$: Precision으로 인한 effective 파라미터 감소 효과
- $C$: 총 compute 예산
- 해석 : 주어진 compute내에서 최적의 Loss를 구할 수 있는, (model size / dataset size / precision ) 조합을 찾는 것
4.3.1 If You Must Train in Low Precision, Increase Parameters Before Data
- $P$가 고정되고 $C \propto NDP$인 제약 조건 하에서 $L(N, D)$를 최소화
- 최적화 식은 아래와 같다.
- $\frac{N^(P, C)}{N_{Ch}(C)} \propto [1 - e^{-P/\gamma}]^{-\frac{3\alpha}{\alpha+\beta}}P^{-\frac{\beta}{\alpha+\beta}}$ 그리고 $\frac{D^(P, C)}{D_{Ch}(C)} \propto [1 - e^{-P/\gamma}]^{\frac{3\alpha}{\alpha+\beta}}P^{\frac{\beta}{\alpha+\beta}}$
- training의 precision이 감소함에 따라 data는 줄이고 parameters는 증가시켜야 함을 시사
- 매우 낮은 precision에서는 effective parameter count가 0에 가까워지므로, data가 effective parameters를 지나치게 많이 초과하기 때문에 parameter count를 증가시키는 것이 compute-optimal임
4.3.2 Compute-Optimal Pretraining Precision is in General Independent of Compute
- $C \propto NDP$인 제약 조건 하에서 $L(N, D, P)$를 함께 최소화
- 최적화 문제를 수학적으로 풀면, 최적의 precision($P^*$)이 compute budget($C$)과 무관하다는 것을 발견
- 실제로 실험해보니 이 최적의 precision은 약 7 bits
- 의미
- 모델을 학습할 때, compute budget이 얼마든 간에 항상 약 7 bits precision을 사용하는 것이 최적
- 더 많은 계산 자원이 생기면, precision을 올리는 대신 모델 크기($N$)나 데이터 크기($D$)를 늘리는 것이 더 효율적
- 현재 산업 표준인 16-bit(BF16) training이 최적이 아닐 수 있다는 것을 시사
- 4-bit 이하의 매우 낮은 precision으로 가는 것도 좋지 않을 수 있다는 것을 보임
- 4비트로 학습시키면 같은 성능을 내기 위해 모델 크기를 4배 이상 키워야 한다. 결국 리소스 측면에서 비효율적이라는 것을 의미
- 모델 크기: 220M부터 1.6B 파라미터까지의 다양한 크기의 모델 테스트
- Precision 범위: FP4(4-bit floating point)부터 FP32(32-bit floating point)까지 테스트
- 4-bit precision까지는 integer type으로 예측한 결과가 실제 floating-point 실험과 잘 맞음
- 4-bit 이하에서는 차이가 발생하기 시작했는데, 이는 두 타입의 근본적인 차이 때문:
- Integer type: 모든 비트가 동일한 가치를 가진다고 가정
- Floating-point type: 비트가 두 부분(지수부/exponent와 가수부/mantissa)으로 나뉘며, 각각이 다른 방식으로 정확도에 영향을 줌
- 연구팀이 예측한 최적의 Precision은 7bit이였는데, 실험에서도 비슷하게 나왔음.
4.3.3 But Compute-Optimal Pretraining Precision Can Increase in Compute if Model Size N is Constrained
- $N$(모델 크기)이 고정되어 있는 상황
- 기존에는 모든 크기의 모델을 같은 precision으로 학습시키는 것이 일반적
- 하지만 연구 결과, 모델 크기에 따라 다른 precision을 사용하는 것이 더 optimal할 수 있음
- 구체적으로, compute-optimal precision은 $\log C$에 비례하여 증가
- 고정된 모델 크기 상황
- 모델 크기($N$)가 미리 정해져 있을 때 : 최적의 precision은 compute 예산에 비례하여 로그 스케일로 증가
- Overtraining 상황에서
- 많은 데이터로 "과학습"(overtrained)되는 경우, 오히려 더 높은 precision으로 학습하는 것이 compute 효율적
- 결론 :
- 모델 설계 시 무조건 낮은 precision(예: 4-bit)이나 높은 precision(16-bit)을 사용하는 것이 아니라, 상황에 따라 적절한 precision을 선택
- 특히 대규모 데이터로 학습하는 경우, 흔히 생각하는 것과 달리 더 높은 precision을 고려할 것
5. A Unified Scaling Law for Precision
- training/post-training을 모두 예측하는 unified functional form으로 결합
- 함수 $\delta_{PTQ}$를
- $\delta_{PTQ}(N, D, P_{post})$에서 $\delta_{PTQ}(N, D, P_{train}, P_{post})$로 다룸
- 뭐 어쨌든 결론적으로, 낮은 precision에서 학습한 모델일수록 post-train quantization에서 성능 저하가 낮다.
- Robustification
- 모델을 낮은 precision으로 학습시키면, 모델이 처음부터 quantization noise에 적응하면서 학습됨
- 험한 환경에서 단련된 것처럼 작용하여, 나중에 post-train quantization을 할 때 더 잘 견딜 수 있게 됨
- post-train quantization : pretrain model을 추론시에 quantization을 하는 것을 의미(추가적인 학습은 없음)
- Overtraining
- 낮은 precision으로 학습하면 모델의 effective parameter count가 감소함
- -> effective parameter count가 낮을수록 post-train quantization으로 인한 성능 저하가 더 커짐
- Robustification
Two Competing effects at play during post-train quantization
- $P_w$, $P_a$, $P_{kv}$ 중 어느 것이든 low precision으로 training하면 모델이 "quantization noise"에 강건한 weight를 학습하도록 강제하므로 PTQ 하에서 덜 저하됨.
- 그러나, low precision으로 학습된 모델의 감소된 $N \mapsto N_{eff}$ effective parameter 수는 Section 3에서 발견한 것처럼 low precision으로 학습된 모델이 더 많이 저하될 것임이 생각됨
- 정리하자면, low precision으로 학습하면 quantization 후 (긍정적) 성능 하락 적을 것으로 생각되나, (부정적) effective paramter 수늘 줄어들어 모델 성능이 줄어들 것임. 즉 두개의 효과가 서로 배반 적임.
- 허나, low precision으로 학습하면 quantization 후 (긍정적) 성능 하락 적을 것의 영향도가 더 큰것으로 확인됨.
Modifying $\delta_{PTQ}$ to account for training precision
- Training precision은 항상 inference precision보다 커야 함
- 두 precision이 같을 경우 성능 저하는 0이라고 정의
- $(N, D)$를 고정했을 때, training precision($P_w$)과 inference precision($P_{post}$) 사이에 어떤 차이라도 있으면 성능 저하가 급격히 발생
-> 이 저하는 지수적으로 증가하는 패턴을 보임 ( Figure 7의 중간 부분 ) - 이 발견을 바탕으로 아래와 같이 formula를 변경시켰음
An interpretable, unified functional form
6. Conclusion and Limiations